Skip to content

Commit 32c18ab

Browse files
committed
feat: add file upload functionality and enhance data source handling
1 parent eb8dbdb commit 32c18ab

File tree

2 files changed

+67
-25
lines changed

2 files changed

+67
-25
lines changed

apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -57,29 +57,6 @@ def save_image(image_list):
5757

5858
document_list = []
5959
for doc in document:
60-
if 'file_bytes' in doc:
61-
file_bytes = doc['file_bytes']
62-
# 如果是字符串,转换为字节
63-
if isinstance(file_bytes, str):
64-
file_bytes = ast.literal_eval(file_bytes)
65-
doc['file_id'] = doc.get('file_id') or uuid.uuid7()
66-
meta = {
67-
'debug': False if (application_id or knowledge_id) else True,
68-
'chat_id': chat_id,
69-
'application_id': str(application_id) if application_id else None,
70-
'knowledge_id': str(knowledge_id) if knowledge_id else None,
71-
'file_id': str(doc['file_id'])
72-
}
73-
new_file = File(
74-
id=doc['file_id'],
75-
file_name=doc['name'],
76-
file_size=len(file_bytes),
77-
source_type=FileSourceType.APPLICATION.value if meta[
78-
'application_id'] else FileSourceType.KNOWLEDGE.value,
79-
source_id=meta['application_id'] if meta['application_id'] else meta['knowledge_id'],
80-
meta={}
81-
)
82-
new_file.save(file_bytes)
8360
file = QuerySet(File).filter(id=doc['file_id']).first()
8461
buffer = io.BytesIO(file.get_bytes())
8562
buffer.name = doc['name'] # this is the important line

apps/application/flow/step_node/tool_lib_node/impl/base_tool_lib_node.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@
66
@date:2024/8/8 17:49
77
@desc:
88
"""
9+
import base64
10+
import io
911
import json
12+
import mimetypes
1013
import time
1114
from typing import Dict
1215

16+
from django.core.files.uploadedfile import InMemoryUploadedFile
1317
from django.db.models import QuerySet
1418
from django.utils.translation import gettext as _
1519

@@ -19,7 +23,8 @@
1923
from common.exception.app_exception import AppApiException
2024
from common.utils.rsa_util import rsa_long_decrypt
2125
from common.utils.tool_code import ToolExecutor
22-
from maxkb.const import CONFIG
26+
from knowledge.models import FileSourceType
27+
from oss.serializers.file import FileSerializer
2328
from tools.models import Tool
2429

2530
function_executor = ToolExecutor()
@@ -126,6 +131,7 @@ def valid_function(tool_lib, workspace_id):
126131
if not tool_lib.is_active:
127132
raise Exception(_("Tool is not active"))
128133

134+
129135
def _filter_file_bytes(data):
130136
"""递归过滤掉所有层级的 file_bytes"""
131137
if isinstance(data, dict):
@@ -136,6 +142,27 @@ def _filter_file_bytes(data):
136142
return data
137143

138144

145+
def bytes_to_uploaded_file(file_bytes, file_name="unknown"):
146+
content_type, _ = mimetypes.guess_type(file_name)
147+
if content_type is None:
148+
# 如果未能识别,设置为默认的二进制文件类型
149+
content_type = "application/octet-stream"
150+
# 创建一个内存中的字节流对象
151+
file_stream = io.BytesIO(file_bytes)
152+
153+
# 获取文件大小
154+
file_size = len(file_bytes)
155+
156+
uploaded_file = InMemoryUploadedFile(
157+
file=file_stream,
158+
field_name=None,
159+
name=file_name,
160+
content_type=content_type,
161+
size=file_size,
162+
charset=None,
163+
)
164+
return uploaded_file
165+
139166

140167
class BaseToolLibNodeNode(IToolLibNode):
141168
def save_context(self, details, workflow_manage):
@@ -168,12 +195,50 @@ def execute(self, tool_lib_id, input_field_list, **kwargs) -> NodeResult:
168195
else:
169196
all_params = init_params_default_value | params
170197
if self.node.properties.get('kind') == 'data-source':
171-
all_params = {**all_params, **self.workflow_params.get('data_source')}
198+
download_file_list = []
199+
download_list = function_executor.exec_code(
200+
tool_lib.code,
201+
{**all_params, **self.workflow_params.get('data_source')},
202+
function_name='get_down_file_list'
203+
)
204+
for item in download_list:
205+
result = function_executor.exec_code(
206+
tool_lib.code,
207+
{**all_params, **self.workflow_params.get('data_source'),
208+
'download_item': item},
209+
function_name='download'
210+
)
211+
file_bytes = result.get('file_bytes', [])
212+
chunks = []
213+
for chunk in file_bytes:
214+
chunks.append(base64.b64decode(chunk))
215+
file = bytes_to_uploaded_file(b''.join(chunks), result.get('name'))
216+
file_url = self.upload_knowledge_file(file)
217+
download_file_list.append({'file_id': file_url.split('/')[-1], 'name': result.get('name')})
218+
all_params = {
219+
**all_params, **self.workflow_params.get('data_source'),
220+
'download_file_list': download_file_list
221+
}
172222
result = function_executor.exec_code(tool_lib.code, all_params)
173223
return NodeResult({'result': result},
174224
(self.workflow_manage.params.get('knowledge_base') or {}) if self.node.properties.get(
175225
'kind') == 'data-source' else {}, _write_context=write_context)
176226

227+
def upload_knowledge_file(self, file):
228+
knowledge_id = self.workflow_params.get('knowledge_id')
229+
meta = {
230+
'debug': False,
231+
'knowledge_id': knowledge_id,
232+
}
233+
file_url = FileSerializer(data={
234+
'file': file,
235+
'meta': meta,
236+
'source_id': knowledge_id,
237+
'source_type': FileSourceType.KNOWLEDGE.value
238+
}).upload()
239+
file.close()
240+
return file_url
241+
177242
def get_details(self, index: int, **kwargs):
178243
result = _filter_file_bytes(self.context.get('result'))
179244

0 commit comments

Comments
 (0)