Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def get_loop_workflow_node(node_list):


def get_workflow_state(workflow):
if workflow.is_the_task_interrupted():
return State.REVOKED
details = workflow.get_runtime_details()
node_list = details.values()
all_node = [*node_list, *get_loop_workflow_node(node_list)]
Expand Down
7 changes: 5 additions & 2 deletions apps/application/flow/knowledge_workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def __init__(self, flow: Workflow,
work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse(),
start_node_id=None,
start_node_data=None, chat_record=None, child_node=None):
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
None,
None, None, start_node_id, start_node_data, chat_record, child_node)
None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)

def get_params_serializer_class(self):
return KnowledgeFlowParamsSerializer
Expand Down Expand Up @@ -91,6 +91,9 @@ def hand_node_result(self, current_node, node_result_future):
list(result)
if current_node.status == 500:
return None
if self.is_the_task_interrupted():
current_node.status = 201
return None
return current_result
except Exception as e:
traceback.print_exc()
Expand Down
4 changes: 2 additions & 2 deletions apps/application/flow/loop_workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ def __init__(self, flow: Workflow,
get_loop_context,
base_to_response: BaseToResponse = SystemToResponse(),
start_node_id=None,
start_node_data=None, chat_record=None, child_node=None):
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
self.parentWorkflowManage = parentWorkflowManage
self.loop_params = loop_params
self.get_loop_context = get_loop_context
self.loop_field_list = []
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
None,
None, None, start_node_id, start_node_data, chat_record, child_node)
None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)

def get_node_cls_by_id(self, node_id, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
response_reasoning_content = False

for chunk in response:
if workflow.is_the_task_interrupted():
break
reasoning_chunk = reasoning.get_reasoning_content(chunk)
content_chunk = reasoning_chunk.get('content')
if 'reasoning_content' in chunk.additional_kwargs:
Expand Down Expand Up @@ -110,7 +112,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
if 'reasoning_content' in meta:
reasoning_content = (meta.get('reasoning_content', '') or '')
else:
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '')
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (
reasoning_result_end.get('reasoning_content') or '')
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ def workflow_manage_new_instance(loop_data, global_data, start_node_id=None,
start_node_id=start_node_id,
start_node_data=start_node_data,
chat_record=chat_record,
child_node=child_node
child_node=child_node,
is_the_task_interrupted=self.workflow_manage.is_the_task_interrupted
)

return workflow_manage
Expand Down
3 changes: 2 additions & 1 deletion apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
video_list=None,
other_list=None,
start_node_id=None,
start_node_data=None, chat_record=None, child_node=None):
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
if form_data is None:
form_data = {}
if image_list is None:
Expand Down Expand Up @@ -138,6 +138,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
self.global_field_list = []
self.chat_field_list = []
self.init_fields()
self.is_the_task_interrupted = is_the_task_interrupted
if start_node_id is not None:
self.load_node(chat_record, start_node_id, start_node_data)
else:
Expand Down
2 changes: 1 addition & 1 deletion apps/common/constants/cache_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Cache_Version(Enum):
SYSTEM = "SYSTEM", lambda key: key
# 应用对接三方应用的缓存
APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key

KNOWLEDGE_WORKFLOW_INTERRUPTED = "KNOWLEDGE_WORKFLOW_INTERRUPTED", lambda action_id: action_id
# 对话
CHAT = "CHAT", lambda key: key

Expand Down
184 changes: 99 additions & 85 deletions apps/common/utils/tool_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,41 @@
import socket
import subprocess
import sys
import tempfile
import pwd
import resource
import getpass
import random
import signal
import time
import uuid_utils.compat as uuid
from contextlib import contextmanager
from common.utils.logger import maxkb_logger
from django.utils.translation import gettext_lazy as _
from maxkb.const import BASE_DIR, CONFIG
from maxkb.const import PROJECT_DIR
from textwrap import dedent

_enable_sandbox = bool(CONFIG.get('SANDBOX', 0))
_run_user = 'sandbox' if _enable_sandbox else getpass.getuser()
_sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox') if _enable_sandbox else os.path.join(PROJECT_DIR, 'data', 'sandbox')
_process_limit_timeout_seconds = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_TIMEOUT_SECONDS", '3600'))
_process_limit_cpu_cores = min(max(int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_CPU_CORES", '1')), 1), len(os.sched_getaffinity(0))) if sys.platform.startswith("linux") else os.cpu_count() # 只支持linux,window和mac不支持
_process_limit_mem_mb = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_MEM_MB", '256'))
python_directory = sys.executable


class ToolExecutor:

def __init__(self):
pass
def __init__(self, sandbox=False):
self.sandbox = sandbox
if sandbox:
self.sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox')
self.user = 'sandbox'
else:
self.sandbox_path = os.path.join(PROJECT_DIR, 'data', 'sandbox')
self.user = None
self.sandbox_so_path = f'{self.sandbox_path}/lib/sandbox.so'
self.process_timeout_seconds = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_TIMEOUT_SECONDS", '3600'))
try:
self._init_sandbox_dir()
except Exception as e:
# 本机忽略异常,容器内不忽略
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if self.sandbox:
raise e

@staticmethod
def init_sandbox_dir():
if not _enable_sandbox:
# 不启用sandbox就不初始化目录
def _init_sandbox_dir(self):
if not self.sandbox:
# 不是sandbox就不初始化目录
return
try:
# 只初始化一次
Expand All @@ -46,7 +51,7 @@ def init_sandbox_dir():
except FileExistsError:
# 文件已存在 → 已初始化过
return
maxkb_logger.info("Init sandbox dir.")
maxkb_logger.debug("init dir")
try:
os.system("chmod -R g-rwx /dev/shm /dev/mqueue")
os.system("chmod o-rwx /run/postgresql")
Expand All @@ -56,7 +61,7 @@ def init_sandbox_dir():
if CONFIG.get("SANDBOX_TMP_DIR_ENABLED", '0') == "1":
os.system("chmod g+rwx /tmp")
# 初始化sandbox配置文件
sandbox_lib_path = os.path.dirname(f'{_sandbox_path}/lib/sandbox.so')
sandbox_lib_path = os.path.dirname(self.sandbox_so_path)
sandbox_conf_file_path = f'{sandbox_lib_path}/.sandbox.conf'
if os.path.exists(sandbox_conf_file_path):
os.remove(sandbox_conf_file_path)
Expand All @@ -69,60 +74,48 @@ def init_sandbox_dir():
with open(sandbox_conf_file_path, "w") as f:
f.write(f"SANDBOX_PYTHON_BANNED_HOSTS={banned_hosts}\n")
f.write(f"SANDBOX_PYTHON_ALLOW_SUBPROCESS={allow_subprocess}\n")
os.system(f"chmod -R 550 {_sandbox_path}")

try:
init_sandbox_dir()
except Exception as e:
maxkb_logger.error(f'Exception: {e}', exc_info=True)
os.system(f"chmod -R 550 {self.sandbox_path}")

def exec_code(self, code_str, keywords, function_name=None):
_id = str(uuid.uuid7())
success = '{"code":200,"msg":"成功","data":exec_result}'
err = '{"code":500,"msg":str(e),"data":None}'
action_function = f'({function_name !a}, locals_v.get({function_name !a}))' if function_name else 'locals_v.popitem()'
python_paths = CONFIG.get_sandbox_python_package_paths().split(',')
set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
_exec_code = f"""
try:
import os, sys, json
from contextlib import redirect_stdout
import os, sys, json, base64, builtins
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
sys.path = [p for p in sys.path if p not in path_to_exclude]
sys.path += {python_paths}
locals_v={{}}
locals_v={'{}'}
keywords={keywords}
globals_v={{}}
{set_run_user}
globals_v={'{}'}
os.environ.clear()
with redirect_stdout(open(os.devnull, 'w')):
exec({dedent(code_str)!a}, globals_v, locals_v)
f_name, f = {action_function}
globals_v.update(locals_v)
exec_result=f(**keywords)
sys.stdout.write("\\n{_id}:")
json.dump({{'code':200,'msg':'success','data':exec_result}}, sys.stdout, default=str)
exec({dedent(code_str)!a}, globals_v, locals_v)
f_name, f = {action_function}
for local in locals_v:
globals_v[local] = locals_v[local]
exec_result=f(**keywords)
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({success}, default=str).encode()).decode())
except Exception as e:
if isinstance(e, MemoryError): e = Exception("Cannot allocate more memory: exceeded the limit of {_process_limit_mem_mb} MB.")
sys.stdout.write("\\n{_id}:")
json.dump({{'code':500,'msg':str(e),'data':None}}, sys.stdout, default=str)
sys.stdout.flush()
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({err}, default=str).encode()).decode())
"""
maxkb_logger.debug(f"Sandbox execute code: {_exec_code}")
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=True) as f:
f.write(_exec_code)
f.flush()
with execution_timer(_id):
subprocess_result = self._exec(f.name)
if self.sandbox:
subprocess_result = self._exec_sandbox(_exec_code)
else:
subprocess_result = self._exec(_exec_code)
if subprocess_result.returncode != 0:
raise Exception(subprocess_result.stderr or subprocess_result.stdout or "Unknown exception occurred")
lines = subprocess_result.stdout.splitlines()
result_line = [line for line in lines if line.startswith(_id)]
if not result_line:
maxkb_logger.error("\n".join(lines))
raise Exception("No result found.")
result = json.loads(result_line[-1].split(":", 1)[1])
result = json.loads(base64.b64decode(result_line[-1].split(":", 1)[1]).decode())
if result.get('code') == 200:
return result.get('data')
raise Exception(result.get('msg') + (f'\n{subprocess_result.stderr}' if subprocess_result.stderr else ''))
raise Exception(result.get('msg'))

def _generate_mcp_server_code(self, _code, params):
# 解析代码,提取导入语句和函数定义
Expand Down Expand Up @@ -190,7 +183,6 @@ def _generate_mcp_server_code(self, _code, params):
def generate_mcp_server_code(self, code_str, params):
python_paths = CONFIG.get_sandbox_python_package_paths().split(',')
code = self._generate_mcp_server_code(code_str, params)
set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
return f"""
import os, sys, logging
logging.basicConfig(level=logging.WARNING)
Expand All @@ -199,7 +191,6 @@ def generate_mcp_server_code(self, code_str, params):
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
sys.path = [p for p in sys.path if p not in path_to_exclude]
sys.path += {python_paths}
{set_run_user}
os.environ.clear()
exec({dedent(code)!a})
"""
Expand All @@ -208,51 +199,74 @@ def get_tool_mcp_config(self, code, params):
_code = self.generate_mcp_server_code(code, params)
maxkb_logger.debug(f"Python code of mcp tool: {_code}")
compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
tool_config = {
'command': sys.executable,
'args': [
'-c',
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
],
'cwd': _sandbox_path,
'env': {
'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
},
'transport': 'stdio',
}
if self.sandbox:
tool_config = {
'command': 'su',
'args': [
'-s', sys.executable,
'-c',
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
self.user,
],
'cwd': self.sandbox_path,
'env': {
'LD_PRELOAD': self.sandbox_so_path,
},
'transport': 'stdio',
}
else:
tool_config = {
'command': sys.executable,
'args': [
'-c',
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
],
'transport': 'stdio',
}
return tool_config

def _exec(self, execute_file):
def _exec_sandbox(self, _code):
kwargs = {'cwd': BASE_DIR, 'env': {
'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
'LD_PRELOAD': self.sandbox_so_path,
}}
maxkb_logger.debug(f"Sandbox execute code: {_code}")
compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
cmd = [
'su', '-s', python_directory, '-c',
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
self.user
]
try:
subprocess_result = subprocess.run(
[sys.executable, execute_file],
timeout=_process_limit_timeout_seconds,
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
capture_output=True,
**kwargs,
preexec_fn=(lambda: None if (not _enable_sandbox or not sys.platform.startswith("linux")) else (
resource.setrlimit(resource.RLIMIT_AS, (_process_limit_mem_mb * 1024 * 1024,) * 2),
os.sched_setaffinity(0, set(random.sample(list(os.sched_getaffinity(0)), _process_limit_cpu_cores)))
))
start_new_session=True
)
proc.wait(timeout=self.process_timeout_seconds)
return subprocess.CompletedProcess(
proc.args,
proc.returncode,
proc.stdout.read(),
proc.stderr.read()
)
return subprocess_result
except subprocess.TimeoutExpired:
raise Exception(_(f"Process execution timed out after {_process_limit_timeout_seconds} seconds."))
pgid = os.getpgid(proc.pid)
os.killpg(pgid, signal.SIGTERM) #温和终止
time.sleep(1) #留出短暂时间让进程清理
if proc.poll() is None: #如果仍未终止,强制终止
os.killpg(pgid, signal.SIGKILL)
proc.wait()
raise Exception(_(f"Process execution timed out after {self.process_timeout_seconds} seconds."))

def validate_mcp_transport(self, code_str):
servers = json.loads(code_str)
for server, config in servers.items():
if config.get('transport') not in ['sse', 'streamable_http']:
raise Exception(_('Only support transport=sse or transport=streamable_http'))


@contextmanager
def execution_timer(id=""):
start = time.perf_counter()
try:
yield
finally:
maxkb_logger.debug(f"Tool execution({id}) takes {time.perf_counter() - start:.6f} seconds.")
@staticmethod
def _exec(_code):
return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True)
Loading
Loading