Skip to content

Commit 2ae813f

Browse files
committed
perf: Optimization of data source file download logic
1 parent e4c8a25 commit 2ae813f

File tree

2 files changed

+101
-3
lines changed

2 files changed

+101
-3
lines changed

apps/application/flow/step_node/tool_lib_node/impl/base_tool_lib_node.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@
66
@date:2024/8/8 17:49
77
@desc:
88
"""
9+
import ast
10+
import io
911
import json
12+
import mimetypes
1013
import time
1114
from typing import Dict
1215

16+
from django.core.files.uploadedfile import InMemoryUploadedFile
1317
from django.db.models import QuerySet
1418
from django.utils.translation import gettext as _
1519

@@ -19,7 +23,8 @@
1923
from common.exception.app_exception import AppApiException
2024
from common.utils.rsa_util import rsa_long_decrypt
2125
from common.utils.tool_code import ToolExecutor
22-
from maxkb.const import CONFIG
26+
from knowledge.models import FileSourceType
27+
from oss.serializers.file import FileSerializer
2328
from tools.models import Tool
2429

2530
function_executor = ToolExecutor()
@@ -126,6 +131,7 @@ def valid_function(tool_lib, workspace_id):
126131
if not tool_lib.is_active:
127132
raise Exception(_("Tool is not active"))
128133

134+
129135
def _filter_file_bytes(data):
130136
"""递归过滤掉所有层级的 file_bytes"""
131137
if isinstance(data, dict):
@@ -136,6 +142,27 @@ def _filter_file_bytes(data):
136142
return data
137143

138144

145+
def bytes_to_uploaded_file(file_bytes, file_name="unknown"):
146+
content_type, _ = mimetypes.guess_type(file_name)
147+
if content_type is None:
148+
# 如果未能识别,设置为默认的二进制文件类型
149+
content_type = "application/octet-stream"
150+
# 创建一个内存中的字节流对象
151+
file_stream = io.BytesIO(file_bytes)
152+
153+
# 获取文件大小
154+
file_size = len(file_bytes)
155+
156+
uploaded_file = InMemoryUploadedFile(
157+
file=file_stream,
158+
field_name=None,
159+
name=file_name,
160+
content_type=content_type,
161+
size=file_size,
162+
charset=None,
163+
)
164+
return uploaded_file
165+
139166

140167
class BaseToolLibNodeNode(IToolLibNode):
141168
def save_context(self, details, workflow_manage):
@@ -168,12 +195,42 @@ def execute(self, tool_lib_id, input_field_list, **kwargs) -> NodeResult:
168195
else:
169196
all_params = init_params_default_value | params
170197
if self.node.properties.get('kind') == 'data-source':
171-
all_params = {**all_params, **self.workflow_params.get('data_source')}
198+
exist = function_executor.exist_function(tool_lib.code, 'get_download_file_list')
199+
if exist:
200+
download_file_list = []
201+
download_list = function_executor.exec_code(tool_lib.code,
202+
{**all_params, **self.workflow_params.get('data_source')},
203+
function_name='get_download_file_list')
204+
for item in download_list:
205+
result = function_executor.exec_code(tool_lib.code,
206+
{**all_params, **self.workflow_params.get('data_source'),
207+
'download_item': item},
208+
function_name='download')
209+
file = bytes_to_uploaded_file(ast.literal_eval(result.get('file_bytes')), result.get('name'))
210+
file_url = self.upload_knowledge_file(file)
211+
download_file_list.append({'file_id': file_url, 'name': result.get('name')})
212+
all_params = {**all_params, **self.workflow_params.get('data_source'),
213+
'download_file_list': download_file_list}
172214
result = function_executor.exec_code(tool_lib.code, all_params)
173215
return NodeResult({'result': result},
174216
(self.workflow_manage.params.get('knowledge_base') or {}) if self.node.properties.get(
175217
'kind') == 'data-source' else {}, _write_context=write_context)
176218

219+
def upload_knowledge_file(self, file):
220+
knowledge_id = self.workflow_params.get('knowledge_id')
221+
meta = {
222+
'debug': False,
223+
'knowledge_id': knowledge_id,
224+
}
225+
file_url = FileSerializer(data={
226+
'file': file,
227+
'meta': meta,
228+
'source_id': knowledge_id,
229+
'source_type': FileSourceType.KNOWLEDGE.value
230+
}).upload().replace("./oss/file/", '')
231+
file.close()
232+
return file_url
233+
177234
def get_details(self, index: int, **kwargs):
178235
result = _filter_file_bytes(self.context.get('result'))
179236

apps/common/utils/tool_code.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,47 @@ def init_sandbox_dir():
7474
except Exception as e:
7575
maxkb_logger.error(f'Exception: {e}', exc_info=True)
7676

77+
def exist_function(self, code_str, name):
78+
_id = str(uuid.uuid7())
79+
python_paths = CONFIG.get_sandbox_python_package_paths().split(',')
80+
set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
81+
_exec_code = f"""
82+
try:
83+
import os, sys, json
84+
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
85+
sys.path = [p for p in sys.path if p not in path_to_exclude]
86+
sys.path += {python_paths}
87+
locals_v={{}}
88+
globals_v={{}}
89+
{set_run_user}
90+
os.environ.clear()
91+
exec({dedent(code_str)!a}, globals_v, locals_v)
92+
exec_result=locals_v.__contains__('{name}')
93+
sys.stdout.write("\\n{_id}:")
94+
json.dump({{'code':200,'msg':'success','data':exec_result}}, sys.stdout, default=str)
95+
except Exception as e:
96+
if isinstance(e, MemoryError): e = Exception("Cannot allocate more memory: exceeded the limit of {_process_limit_mem_mb} MB.")
97+
sys.stdout.write("\\n{_id}:")
98+
json.dump({{'code':500,'msg':str(e),'data':False}}, sys.stdout, default=str)
99+
sys.stdout.flush()
100+
"""
101+
maxkb_logger.debug(f"Sandbox execute code: {_exec_code}")
102+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=True) as f:
103+
f.write(_exec_code)
104+
f.flush()
105+
subprocess_result = self._exec(f.name)
106+
if subprocess_result.returncode != 0:
107+
raise Exception(subprocess_result.stderr or subprocess_result.stdout or "Unknown exception occurred")
108+
lines = subprocess_result.stdout.splitlines()
109+
result_line = [line for line in lines if line.startswith(_id)]
110+
if not result_line:
111+
maxkb_logger.error("\n".join(lines))
112+
raise Exception("No result found.")
113+
result = json.loads(result_line[-1].split(":", 1)[1])
114+
if result.get('code') == 200:
115+
return result.get('data')
116+
raise Exception(result.get('msg'))
117+
77118
def exec_code(self, code_str, keywords, function_name=None):
78119
_id = str(uuid.uuid7())
79120
action_function = f'({function_name !a}, locals_v.get({function_name !a}))' if function_name else 'locals_v.popitem()'
@@ -212,7 +253,7 @@ def get_tool_mcp_config(self, code, params):
212253
],
213254
'cwd': _sandbox_path,
214255
'env': {
215-
'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
256+
'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
216257
},
217258
'transport': 'stdio',
218259
}

0 commit comments

Comments
 (0)