Skip to content

Commit 29b4370

Browse files
committed
feat: 高级编排支持文件上传(WIP)
1 parent c866d05 commit 29b4370

File tree

12 files changed

+90
-30
lines changed

12 files changed

+90
-30
lines changed

apps/application/flow/step_node/document_extract_node/i_document_extract_node.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99

1010

1111
class DocumentExtractNodeSerializer(serializers.Serializer):
12-
# 需要查询的数据集id列表
13-
file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
14-
error_messages=ErrMessage.list("数据集id列表"))
15-
16-
def is_valid(self, *, raise_exception=False):
17-
super().is_valid(raise_exception=True)
12+
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
1813

1914

2015
class IDocumentExtractNode(INode):
@@ -24,7 +19,9 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
2419
return DocumentExtractNodeSerializer
2520

2621
def _run(self):
27-
return self.execute(**self.flow_params_serializer.data)
22+
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('document_list')[0],
23+
self.node_params_serializer.data.get('document_list')[1:])
24+
return self.execute(document=res, **self.flow_params_serializer.data)
2825

29-
def execute(self, file_list, **kwargs) -> NodeResult:
26+
def execute(self, document, **kwargs) -> NodeResult:
3027
pass
Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
11
# coding=utf-8
2-
2+
from application.flow.i_step_node import NodeResult
33
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
44

55

66
class BaseDocumentExtractNode(IDocumentExtractNode):
7-
def execute(self, file_list, **kwargs):
8-
pass
7+
def execute(self, document, **kwargs):
8+
self.context['document_list'] = document
9+
content = ''
10+
if len(document) > 0:
11+
for doc in document:
12+
content += doc['name']
13+
content += '\n-----------------------------------\n'
14+
return NodeResult({'content': content}, {})
915

1016
def get_details(self, index: int, **kwargs):
11-
pass
17+
return {
18+
'name': self.node.properties.get('stepName'),
19+
"index": index,
20+
'run_time': self.context.get('run_time'),
21+
'type': self.node.type,
22+
'content': self.context.get('content'),
23+
'status': self.status,
24+
'err_message': self.err_message,
25+
'document_list': self.context.get('document_list')
26+
}

apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
1818

1919
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
2020

21-
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
21+
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
2222

2323

2424
class IImageUnderstandNode(INode):

apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
2525
node.context['question'] = node_variable['question']
2626
node.context['run_time'] = time.time() - node.context['start_time']
2727
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
28-
workflow.answer += answer
28+
node.answer_text = answer
2929

3030

3131
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):

apps/application/flow/step_node/start_node/impl/base_start_node.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ def execute(self, question, **kwargs) -> NodeResult:
5252
"""
5353
开始节点 初始化全局变量
5454
"""
55-
return NodeResult({'question': question, 'image': self.workflow_manage.image_list},
56-
workflow_variable)
55+
node_variable = {
56+
'question': question,
57+
'image': self.workflow_manage.image_list,
58+
'document': self.workflow_manage.document_list
59+
}
60+
return NodeResult(node_variable, workflow_variable)
5761

5862
def get_details(self, index: int, **kwargs):
5963
global_fields = []

apps/application/flow/workflow_manage.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,20 @@ def is_end(self):
240240
class WorkflowManage:
241241
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
242242
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
243+
document_list=None,
243244
start_node_id=None,
244245
start_node_data=None, chat_record=None):
245246
if form_data is None:
246247
form_data = {}
247248
if image_list is None:
248249
image_list = []
250+
if document_list is None:
251+
document_list = []
249252
self.start_node = None
250253
self.start_node_result_future = None
251254
self.form_data = form_data
252255
self.image_list = image_list
256+
self.document_list = document_list
253257
self.params = params
254258
self.flow = flow
255259
self.lock = threading.Lock()

apps/application/serializers/chat_message_serializers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ class ChatMessageSerializer(serializers.Serializer):
230230
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
231231
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
232232
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
233-
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
233+
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
234+
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
234235

235236
def is_valid_application_workflow(self, *, raise_exception=False):
236237
self.is_valid_intraday_access_num()
@@ -322,6 +323,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
322323
client_type = self.data.get('client_type')
323324
form_data = self.data.get('form_data')
324325
image_list = self.data.get('image_list')
326+
document_list = self.data.get('document_list')
325327
user_id = chat_info.application.user_id
326328
chat_record_id = self.data.get('chat_record_id')
327329
chat_record = None
@@ -336,7 +338,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
336338
'client_id': client_id,
337339
'client_type': client_type,
338340
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
339-
base_to_response, form_data, image_list, self.data.get('runtime_node_id'),
341+
base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'),
340342
self.data.get('node_data'), chat_record)
341343
r = work_flow_manage.run()
342344
return r

apps/application/views/chat_views.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ def post(self, request: Request, chat_id: str):
132132

133133
'image_list': request.data.get(
134134
'image_list') if 'image_list' in request.data else [],
135+
'document_list': request.data.get(
136+
'document_list') if 'document_list' in request.data else [],
135137
'client_type': request.auth.client_type,
136138
'runtime_node_id': request.data.get('runtime_node_id', None),
137139
'node_data': request.data.get('node_data', {}),

ui/src/api/type/application.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ interface chatType {
3939
record_id: string
4040
chat_id: string
4141
vote_status: string
42-
status?: number
42+
status?: number,
43+
execution_details: any[]
4344
}
4445

4546
export class ChatRecordManage {

ui/src/components/ai-chat/component/chat-input-operate/index.vue

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020

2121
<div class="operate flex align-center">
2222
<span v-if="props.applicationDetails.file_upload_enable" class="flex align-center">
23+
<!-- accept="image/jpeg, image/png, image/gif"-->
2324
<el-upload
2425
action="#"
2526
:auto-upload="false"
2627
:show-file-list="false"
28+
:accept="[...imageExtensions, ...documentExtensions].map((ext) => '.' + ext).join(',')"
2729
:on-change="(file: any, fileList: any) => uploadFile(file, fileList)"
2830
>
2931
<el-button text>
@@ -126,6 +128,13 @@ const localLoading = computed({
126128
emit('update:loading', v)
127129
}
128130
})
131+
132+
133+
const imageExtensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp']
134+
const documentExtensions = ['pdf', 'docx', 'txt', 'xls', 'xlsx', 'md', 'html']
135+
const videoExtensions = ['mp4', 'avi', 'mov', 'mkv', 'flv']
136+
const audioExtensions = ['mp3', 'wav', 'aac', 'flac']
137+
129138
const uploadFile = async (file: any, fileList: any) => {
130139
const { maxFiles, fileLimit } = props.applicationDetails.file_upload_setting
131140
if (fileList.length > maxFiles) {
@@ -141,7 +150,18 @@ const uploadFile = async (file: any, fileList: any) => {
141150
const formData = new FormData()
142151
for (const file of fileList) {
143152
formData.append('file', file.raw, file.name)
144-
uploadFileList.value.push(file)
153+
//
154+
const extension = file.name.split('.').pop().toLowerCase() // 获取文件后缀名并转为小写
155+
156+
if (imageExtensions.includes(extension)) {
157+
uploadImageList.value.push(file)
158+
} else if (documentExtensions.includes(extension)) {
159+
uploadDocumentList.value.push(file)
160+
} else if (videoExtensions.includes(extension)) {
161+
// videos.push(file)
162+
} else if (audioExtensions.includes(extension)) {
163+
// audios.push(file)
164+
}
145165
}
146166
147167
if (!chatId_context.value) {
@@ -158,21 +178,22 @@ const uploadFile = async (file: any, fileList: any) => {
158178
)
159179
.then((response) => {
160180
fileList.splice(0, fileList.length)
161-
uploadFileList.value.forEach((file: any) => {
181+
uploadImageList.value.forEach((file: any) => {
162182
const f = response.data.filter((f: any) => f.name === file.name)
163183
if (f.length > 0) {
164184
file.url = f[0].url
165185
file.file_id = f[0].file_id
166186
}
167187
})
168-
console.log(uploadFileList.value)
188+
console.log(uploadDocumentList.value, uploadImageList.value)
169189
})
170190
}
171191
const recorderTime = ref(0)
172192
const startRecorderTime = ref(false)
173193
const recorderLoading = ref(false)
174194
const inputValue = ref<string>('')
175-
const uploadFileList = ref<Array<any>>([])
195+
const uploadImageList = ref<Array<any>>([])
196+
const uploadDocumentList = ref<Array<any>>([])
176197
const mediaRecorderStatus = ref(true)
177198
// 定义响应式引用
178199
const mediaRecorder = ref<any>(null)
@@ -289,15 +310,20 @@ const handleTimeChange = () => {
289310
handleTimeChange()
290311
}, 1000)
291312
}
313+
292314
function sendChatHandle(event: any) {
293315
if (!event.ctrlKey) {
294316
// 如果没有按下组合键ctrl,则会阻止默认事件
295317
event.preventDefault()
296318
if (!isDisabledChart.value && !props.loading && !event.isComposing) {
297319
if (inputValue.value.trim()) {
298-
props.sendMessage(inputValue.value, { image_list: uploadFileList.value })
320+
props.sendMessage(inputValue.value, {
321+
image_list: uploadImageList.value,
322+
document_list: uploadDocumentList.value
323+
})
299324
inputValue.value = ''
300-
uploadFileList.value = []
325+
uploadImageList.value = []
326+
uploadDocumentList.value = []
301327
quickInputRef.value.textareaStyle.height = '45px'
302328
}
303329
}

0 commit comments

Comments
 (0)