Skip to content

Commit a4167ea

Browse files
authored
feat: Knowledge base workflow supports terminating execution (#4535)
1 parent d72c660 commit a4167ea

File tree

14 files changed

+218
-103
lines changed

14 files changed

+218
-103
lines changed

apps/application/flow/i_step_node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def get_loop_workflow_node(node_list):
123123

124124

125125
def get_workflow_state(workflow):
126+
if workflow.is_the_task_interrupted():
127+
return State.REVOKED
126128
details = workflow.get_runtime_details()
127129
node_list = details.values()
128130
all_node = [*node_list, *get_loop_workflow_node(node_list)]

apps/application/flow/knowledge_workflow_manage.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def __init__(self, flow: Workflow,
3030
work_flow_post_handler: WorkFlowPostHandler,
3131
base_to_response: BaseToResponse = SystemToResponse(),
3232
start_node_id=None,
33-
start_node_data=None, chat_record=None, child_node=None):
33+
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
3434
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
3535
None,
36-
None, None, start_node_id, start_node_data, chat_record, child_node)
36+
None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)
3737

3838
def get_params_serializer_class(self):
3939
return KnowledgeFlowParamsSerializer
@@ -91,6 +91,9 @@ def hand_node_result(self, current_node, node_result_future):
9191
list(result)
9292
if current_node.status == 500:
9393
return None
94+
if self.is_the_task_interrupted():
95+
current_node.status = 201
96+
return None
9497
return current_result
9598
except Exception as e:
9699
traceback.print_exc()

apps/application/flow/loop_workflow_manage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def __init__(self, flow: Workflow,
9292
get_loop_context,
9393
base_to_response: BaseToResponse = SystemToResponse(),
9494
start_node_id=None,
95-
start_node_data=None, chat_record=None, child_node=None):
95+
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
9696
self.parentWorkflowManage = parentWorkflowManage
9797
self.loop_params = loop_params
9898
self.get_loop_context = get_loop_context
9999
self.loop_field_list = []
100100
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
101101
None,
102-
None, None, start_node_id, start_node_data, chat_record, child_node)
102+
None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)
103103

104104
def get_node_cls_by_id(self, node_id, up_node_id_list=None,
105105
get_node_params=lambda node: node.properties.get('node_data')):

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
6363
response_reasoning_content = False
6464

6565
for chunk in response:
66+
if workflow.is_the_task_interrupted():
67+
break
6668
reasoning_chunk = reasoning.get_reasoning_content(chunk)
6769
content_chunk = reasoning_chunk.get('content')
6870
if 'reasoning_content' in chunk.additional_kwargs:
@@ -110,7 +112,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
110112
if 'reasoning_content' in meta:
111113
reasoning_content = (meta.get('reasoning_content', '') or '')
112114
else:
113-
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '')
115+
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (
116+
reasoning_result_end.get('reasoning_content') or '')
114117
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
115118

116119

apps/application/flow/step_node/loop_node/impl/base_loop_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ def workflow_manage_new_instance(loop_data, global_data, start_node_id=None,
268268
start_node_id=start_node_id,
269269
start_node_data=start_node_data,
270270
chat_record=chat_record,
271-
child_node=child_node
271+
child_node=child_node,
272+
is_the_task_interrupted=self.workflow_manage.is_the_task_interrupted
272273
)
273274

274275
return workflow_manage

apps/application/flow/workflow_manage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
9797
video_list=None,
9898
other_list=None,
9999
start_node_id=None,
100-
start_node_data=None, chat_record=None, child_node=None):
100+
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
101101
if form_data is None:
102102
form_data = {}
103103
if image_list is None:
@@ -138,6 +138,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
138138
self.global_field_list = []
139139
self.chat_field_list = []
140140
self.init_fields()
141+
self.is_the_task_interrupted = is_the_task_interrupted
141142
if start_node_id is not None:
142143
self.load_node(chat_record, start_node_id, start_node_data)
143144
else:

apps/common/constants/cache_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Cache_Version(Enum):
2626
SYSTEM = "SYSTEM", lambda key: key
2727
# 应用对接三方应用的缓存
2828
APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key
29-
29+
KNOWLEDGE_WORKFLOW_INTERRUPTED = "KNOWLEDGE_WORKFLOW_INTERRUPTED", lambda action_id: action_id
3030
# 对话
3131
CHAT = "CHAT", lambda key: key
3232

apps/common/utils/tool_code.py

Lines changed: 99 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,41 @@
77
import socket
88
import subprocess
99
import sys
10-
import tempfile
11-
import pwd
12-
import resource
13-
import getpass
14-
import random
10+
import signal
1511
import time
1612
import uuid_utils.compat as uuid
17-
from contextlib import contextmanager
1813
from common.utils.logger import maxkb_logger
1914
from django.utils.translation import gettext_lazy as _
2015
from maxkb.const import BASE_DIR, CONFIG
2116
from maxkb.const import PROJECT_DIR
2217
from textwrap import dedent
2318

24-
_enable_sandbox = bool(CONFIG.get('SANDBOX', 0))
25-
_run_user = 'sandbox' if _enable_sandbox else getpass.getuser()
26-
_sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox') if _enable_sandbox else os.path.join(PROJECT_DIR, 'data', 'sandbox')
27-
_process_limit_timeout_seconds = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_TIMEOUT_SECONDS", '3600'))
28-
_process_limit_cpu_cores = min(max(int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_CPU_CORES", '1')), 1), len(os.sched_getaffinity(0))) if sys.platform.startswith("linux") else os.cpu_count() # 只支持linux,window和mac不支持
29-
_process_limit_mem_mb = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_MEM_MB", '256'))
19+
python_directory = sys.executable
20+
3021

3122
class ToolExecutor:
3223

33-
def __init__(self):
34-
pass
24+
def __init__(self, sandbox=False):
25+
self.sandbox = sandbox
26+
if sandbox:
27+
self.sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox')
28+
self.user = 'sandbox'
29+
else:
30+
self.sandbox_path = os.path.join(PROJECT_DIR, 'data', 'sandbox')
31+
self.user = None
32+
self.sandbox_so_path = f'{self.sandbox_path}/lib/sandbox.so'
33+
self.process_timeout_seconds = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_TIMEOUT_SECONDS", '3600'))
34+
try:
35+
self._init_sandbox_dir()
36+
except Exception as e:
37+
# 本机忽略异常,容器内不忽略
38+
maxkb_logger.error(f'Exception: {e}', exc_info=True)
39+
if self.sandbox:
40+
raise e
3541

36-
@staticmethod
37-
def init_sandbox_dir():
38-
if not _enable_sandbox:
39-
# 不启用sandbox就不初始化目录
42+
def _init_sandbox_dir(self):
43+
if not self.sandbox:
44+
# 不是sandbox就不初始化目录
4045
return
4146
try:
4247
# 只初始化一次
@@ -46,7 +51,7 @@ def init_sandbox_dir():
4651
except FileExistsError:
4752
# 文件已存在 → 已初始化过
4853
return
49-
maxkb_logger.info("Init sandbox dir.")
54+
maxkb_logger.debug("init dir")
5055
try:
5156
os.system("chmod -R g-rwx /dev/shm /dev/mqueue")
5257
os.system("chmod o-rwx /run/postgresql")
@@ -56,7 +61,7 @@ def init_sandbox_dir():
5661
if CONFIG.get("SANDBOX_TMP_DIR_ENABLED", '0') == "1":
5762
os.system("chmod g+rwx /tmp")
5863
# 初始化sandbox配置文件
59-
sandbox_lib_path = os.path.dirname(f'{_sandbox_path}/lib/sandbox.so')
64+
sandbox_lib_path = os.path.dirname(self.sandbox_so_path)
6065
sandbox_conf_file_path = f'{sandbox_lib_path}/.sandbox.conf'
6166
if os.path.exists(sandbox_conf_file_path):
6267
os.remove(sandbox_conf_file_path)
@@ -69,60 +74,48 @@ def init_sandbox_dir():
6974
with open(sandbox_conf_file_path, "w") as f:
7075
f.write(f"SANDBOX_PYTHON_BANNED_HOSTS={banned_hosts}\n")
7176
f.write(f"SANDBOX_PYTHON_ALLOW_SUBPROCESS={allow_subprocess}\n")
72-
os.system(f"chmod -R 550 {_sandbox_path}")
73-
74-
try:
75-
init_sandbox_dir()
76-
except Exception as e:
77-
maxkb_logger.error(f'Exception: {e}', exc_info=True)
77+
os.system(f"chmod -R 550 {self.sandbox_path}")
7878

7979
def exec_code(self, code_str, keywords, function_name=None):
8080
_id = str(uuid.uuid7())
81+
success = '{"code":200,"msg":"成功","data":exec_result}'
82+
err = '{"code":500,"msg":str(e),"data":None}'
8183
action_function = f'({function_name !a}, locals_v.get({function_name !a}))' if function_name else 'locals_v.popitem()'
8284
python_paths = CONFIG.get_sandbox_python_package_paths().split(',')
83-
set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
8485
_exec_code = f"""
8586
try:
86-
import os, sys, json
87-
from contextlib import redirect_stdout
87+
import os, sys, json, base64, builtins
8888
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
8989
sys.path = [p for p in sys.path if p not in path_to_exclude]
9090
sys.path += {python_paths}
91-
locals_v={{}}
91+
locals_v={'{}'}
9292
keywords={keywords}
93-
globals_v={{}}
94-
{set_run_user}
93+
globals_v={'{}'}
9594
os.environ.clear()
96-
with redirect_stdout(open(os.devnull, 'w')):
97-
exec({dedent(code_str)!a}, globals_v, locals_v)
98-
f_name, f = {action_function}
99-
globals_v.update(locals_v)
100-
exec_result=f(**keywords)
101-
sys.stdout.write("\\n{_id}:")
102-
json.dump({{'code':200,'msg':'success','data':exec_result}}, sys.stdout, default=str)
95+
exec({dedent(code_str)!a}, globals_v, locals_v)
96+
f_name, f = {action_function}
97+
for local in locals_v:
98+
globals_v[local] = locals_v[local]
99+
exec_result=f(**keywords)
100+
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({success}, default=str).encode()).decode())
103101
except Exception as e:
104-
if isinstance(e, MemoryError): e = Exception("Cannot allocate more memory: exceeded the limit of {_process_limit_mem_mb} MB.")
105-
sys.stdout.write("\\n{_id}:")
106-
json.dump({{'code':500,'msg':str(e),'data':None}}, sys.stdout, default=str)
107-
sys.stdout.flush()
102+
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({err}, default=str).encode()).decode())
108103
"""
109-
maxkb_logger.debug(f"Sandbox execute code: {_exec_code}")
110-
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=True) as f:
111-
f.write(_exec_code)
112-
f.flush()
113-
with execution_timer(_id):
114-
subprocess_result = self._exec(f.name)
104+
if self.sandbox:
105+
subprocess_result = self._exec_sandbox(_exec_code)
106+
else:
107+
subprocess_result = self._exec(_exec_code)
115108
if subprocess_result.returncode != 0:
116109
raise Exception(subprocess_result.stderr or subprocess_result.stdout or "Unknown exception occurred")
117110
lines = subprocess_result.stdout.splitlines()
118111
result_line = [line for line in lines if line.startswith(_id)]
119112
if not result_line:
120113
maxkb_logger.error("\n".join(lines))
121114
raise Exception("No result found.")
122-
result = json.loads(result_line[-1].split(":", 1)[1])
115+
result = json.loads(base64.b64decode(result_line[-1].split(":", 1)[1]).decode())
123116
if result.get('code') == 200:
124117
return result.get('data')
125-
raise Exception(result.get('msg') + (f'\n{subprocess_result.stderr}' if subprocess_result.stderr else ''))
118+
raise Exception(result.get('msg'))
126119

127120
def _generate_mcp_server_code(self, _code, params):
128121
# 解析代码,提取导入语句和函数定义
@@ -190,7 +183,6 @@ def _generate_mcp_server_code(self, _code, params):
190183
def generate_mcp_server_code(self, code_str, params):
191184
python_paths = CONFIG.get_sandbox_python_package_paths().split(',')
192185
code = self._generate_mcp_server_code(code_str, params)
193-
set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
194186
return f"""
195187
import os, sys, logging
196188
logging.basicConfig(level=logging.WARNING)
@@ -199,7 +191,6 @@ def generate_mcp_server_code(self, code_str, params):
199191
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
200192
sys.path = [p for p in sys.path if p not in path_to_exclude]
201193
sys.path += {python_paths}
202-
{set_run_user}
203194
os.environ.clear()
204195
exec({dedent(code)!a})
205196
"""
@@ -208,51 +199,74 @@ def get_tool_mcp_config(self, code, params):
208199
_code = self.generate_mcp_server_code(code, params)
209200
maxkb_logger.debug(f"Python code of mcp tool: {_code}")
210201
compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
211-
tool_config = {
212-
'command': sys.executable,
213-
'args': [
214-
'-c',
215-
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
216-
],
217-
'cwd': _sandbox_path,
218-
'env': {
219-
'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
220-
},
221-
'transport': 'stdio',
222-
}
202+
if self.sandbox:
203+
tool_config = {
204+
'command': 'su',
205+
'args': [
206+
'-s', sys.executable,
207+
'-c',
208+
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
209+
self.user,
210+
],
211+
'cwd': self.sandbox_path,
212+
'env': {
213+
'LD_PRELOAD': self.sandbox_so_path,
214+
},
215+
'transport': 'stdio',
216+
}
217+
else:
218+
tool_config = {
219+
'command': sys.executable,
220+
'args': [
221+
'-c',
222+
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
223+
],
224+
'transport': 'stdio',
225+
}
223226
return tool_config
224227

225-
def _exec(self, execute_file):
228+
def _exec_sandbox(self, _code):
226229
kwargs = {'cwd': BASE_DIR, 'env': {
227-
'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
230+
'LD_PRELOAD': self.sandbox_so_path,
228231
}}
232+
maxkb_logger.debug(f"Sandbox execute code: {_code}")
233+
compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
234+
cmd = [
235+
'su', '-s', python_directory, '-c',
236+
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
237+
self.user
238+
]
229239
try:
230-
subprocess_result = subprocess.run(
231-
[sys.executable, execute_file],
232-
timeout=_process_limit_timeout_seconds,
240+
proc = subprocess.Popen(
241+
cmd,
242+
stdout=subprocess.PIPE,
243+
stderr=subprocess.PIPE,
233244
text=True,
234-
capture_output=True,
235245
**kwargs,
236-
preexec_fn=(lambda: None if (not _enable_sandbox or not sys.platform.startswith("linux")) else (
237-
resource.setrlimit(resource.RLIMIT_AS, (_process_limit_mem_mb * 1024 * 1024,) * 2),
238-
os.sched_setaffinity(0, set(random.sample(list(os.sched_getaffinity(0)), _process_limit_cpu_cores)))
239-
))
246+
start_new_session=True
247+
)
248+
proc.wait(timeout=self.process_timeout_seconds)
249+
return subprocess.CompletedProcess(
250+
proc.args,
251+
proc.returncode,
252+
proc.stdout.read(),
253+
proc.stderr.read()
240254
)
241-
return subprocess_result
242255
except subprocess.TimeoutExpired:
243-
raise Exception(_(f"Process execution timed out after {_process_limit_timeout_seconds} seconds."))
256+
pgid = os.getpgid(proc.pid)
257+
os.killpg(pgid, signal.SIGTERM) #温和终止
258+
time.sleep(1) #留出短暂时间让进程清理
259+
if proc.poll() is None: #如果仍未终止,强制终止
260+
os.killpg(pgid, signal.SIGKILL)
261+
proc.wait()
262+
raise Exception(_(f"Process execution timed out after {self.process_timeout_seconds} seconds."))
244263

245264
def validate_mcp_transport(self, code_str):
246265
servers = json.loads(code_str)
247266
for server, config in servers.items():
248267
if config.get('transport') not in ['sse', 'streamable_http']:
249268
raise Exception(_('Only support transport=sse or transport=streamable_http'))
250269

251-
252-
@contextmanager
253-
def execution_timer(id=""):
254-
start = time.perf_counter()
255-
try:
256-
yield
257-
finally:
258-
maxkb_logger.debug(f"Tool execution({id}) takes {time.perf_counter() - start:.6f} seconds.")
270+
@staticmethod
271+
def _exec(_code):
272+
return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True)

0 commit comments

Comments
 (0)