Skip to content

Commit b8aa475

Browse files
authored
fix: 修复工作流节点输出等问题 (#1716)
1 parent bce2558 commit b8aa475

File tree

13 files changed

+313
-73
lines changed

13 files changed

+313
-73
lines changed

apps/application/flow/i_step_node.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
import uuid
1111
from abc import abstractmethod
12+
from hashlib import sha1
1213
from typing import Type, Dict, List
1314

1415
from django.core import cache
@@ -131,6 +132,7 @@ class FlowParamsSerializer(serializers.Serializer):
131132

132133

133134
class INode:
135+
view_type = 'many_view'
134136

135137
@abstractmethod
136138
def save_context(self, details, workflow_manage):
@@ -139,7 +141,7 @@ def save_context(self, details, workflow_manage):
139141
def get_answer_text(self):
140142
return self.answer_text
141143

142-
def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None):
144+
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None):
143145
# 当前步骤上下文,用于存储当前步骤信息
144146
self.status = 200
145147
self.err_message = ''
@@ -152,10 +154,13 @@ def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None)
152154
self.context = {}
153155
self.answer_text = None
154156
self.id = node.id
155-
if runtime_node_id is None:
156-
self.runtime_node_id = str(uuid.uuid1())
157-
else:
158-
self.runtime_node_id = runtime_node_id
157+
if up_node_id_list is None:
158+
up_node_id_list = []
159+
self.up_node_id_list = up_node_id_list
160+
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
161+
"".join([*sorted(up_node_id_list),
162+
node.id]))),
163+
"utf-8")).hexdigest()
159164

160165
def valid_args(self, node_params, flow_params):
161166
flow_params_serializer_class = self.get_flow_params_serializer_class()

apps/application/flow/step_node/form_node/i_form_node.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class FormNodeParamsSerializer(serializers.Serializer):
2121

2222
class IFormNode(INode):
2323
type = 'form-node'
24+
view_type = 'single_view'
2425

2526
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
2627
return FormNodeParamsSerializer

apps/application/flow/step_node/form_node/impl/base_form_node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def save_context(self, details, workflow_manage):
3434
self.context['form_field_list'] = details.get('form_field_list')
3535
self.context['run_time'] = details.get('run_time')
3636
self.context['start_time'] = details.get('start_time')
37+
self.context['form_data'] = details.get('form_data')
38+
self.context['is_submit'] = details.get('is_submit')
3739
self.answer_text = details.get('result')
3840

3941
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
@@ -77,6 +79,7 @@ def get_details(self, index: int, **kwargs):
7779
"form_field_list": self.context.get('form_field_list'),
7880
'form_data': self.context.get('form_data'),
7981
'start_time': self.context.get('start_time'),
82+
'is_submit': self.context.get('is_submit'),
8083
'run_time': self.context.get('run_time'),
8184
'type': self.node.type,
8285
'status': self.status,

apps/application/flow/workflow_manage.py

Lines changed: 93 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa
5252
self.__setattr__(keyword, kwargs.get(keyword))
5353

5454

55-
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', 'image-understand-node']
55+
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
56+
'image-understand-node']
5657

5758

5859
class Flow:
@@ -229,7 +230,9 @@ def __init__(self):
229230
def add_chunk(self, chunk):
230231
self.chunk_list.append(chunk)
231232

232-
def end(self):
233+
def end(self, chunk=None):
234+
if chunk is not None:
235+
self.add_chunk(chunk)
233236
self.status = 200
234237

235238
def is_end(self):
@@ -266,6 +269,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
266269
self.status = 0
267270
self.base_to_response = base_to_response
268271
self.chat_record = chat_record
272+
self.await_future_map = {}
269273
if start_node_id is not None:
270274
self.load_node(chat_record, start_node_id, start_node_data)
271275
else:
@@ -286,14 +290,16 @@ def load_node(self, chat_record, start_node_id, start_node_data):
286290
for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')):
287291
node_id = node_details.get('node_id')
288292
if node_details.get('runtime_node_id') == start_node_id:
289-
self.start_node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id'))
293+
self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
290294
self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params)
291295
self.start_node.save_context(node_details, self)
292296
node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {})
293297
self.start_node_result_future = NodeResultFuture(node_result, None)
294-
return
298+
self.node_context.append(self.start_node)
299+
continue
300+
295301
node_id = node_details.get('node_id')
296-
node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id'))
302+
node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
297303
node.valid_args(node.node_params, node.workflow_params)
298304
node.save_context(node_details, self)
299305
self.node_context.append(node)
@@ -345,17 +351,22 @@ def await_result(self, result):
345351
if chunk is None:
346352
break
347353
yield chunk
348-
yield self.get_chunk_content('', True)
349354
finally:
350355
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
351356
self.answer,
352357
self)
353-
yield self.get_chunk_content('', True)
354358

355359
def run_chain_async(self, current_node, node_result_future):
356360
future = executor.submit(self.run_chain, current_node, node_result_future)
357361
return future
358362

363+
def set_await_map(self, node_run_list):
364+
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
365+
for index in range(len(sorted_node_run_list)):
366+
self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
367+
sorted_node_run_list[i].get('future')
368+
for i in range(index)]
369+
359370
def run_chain(self, current_node, node_result_future=None):
360371
if current_node is None:
361372
start_node = self.get_start_node()
@@ -365,6 +376,9 @@ def run_chain(self, current_node, node_result_future=None):
365376
try:
366377
is_stream = self.params.get('stream', True)
367378
# 处理节点响应
379+
await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
380+
if await_future_list is not None:
381+
[f.result() for f in await_future_list]
368382
result = self.hand_event_node_result(current_node,
369383
node_result_future) if is_stream else self.hand_node_result(
370384
current_node, node_result_future)
@@ -373,11 +387,9 @@ def run_chain(self, current_node, node_result_future=None):
373387
return
374388
node_list = self.get_next_node_list(current_node, result)
375389
# 获取到可执行的子节点
376-
result_list = []
377-
for node in node_list:
378-
result = self.run_chain_async(node, None)
379-
result_list.append(result)
380-
[r.result() for r in result_list]
390+
result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list]
391+
self.set_await_map(result_list)
392+
[r.get('future').result() for r in result_list]
381393
if self.status == 0:
382394
self.status = 200
383395
except Exception as e:
@@ -401,6 +413,14 @@ def hand_node_result(self, current_node, node_result_future):
401413
current_node.get_write_error_context(e)
402414
self.answer += str(e)
403415

416+
def append_node(self, current_node):
417+
for index in range(len(self.node_context)):
418+
n = self.node_context[index]
419+
if current_node.id == n.node.id and current_node.runtime_node_id == n.runtime_node_id:
420+
self.node_context[index] = current_node
421+
return
422+
self.node_context.append(current_node)
423+
404424
def hand_event_node_result(self, current_node, node_result_future):
405425
node_chunk = NodeChunk()
406426
try:
@@ -412,22 +432,35 @@ def hand_event_node_result(self, current_node, node_result_future):
412432
for r in result:
413433
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
414434
self.params['chat_record_id'],
415-
r, False, 0, 0)
435+
current_node.id,
436+
current_node.up_node_id_list,
437+
r, False, 0, 0,
438+
{'node_type': current_node.type,
439+
'view_type': current_node.view_type})
416440
node_chunk.add_chunk(chunk)
417-
node_chunk.end()
441+
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
442+
self.params['chat_record_id'],
443+
current_node.id,
444+
current_node.up_node_id_list,
445+
'', False, 0, 0, {'node_is_end': True,
446+
'node_type': current_node.type,
447+
'view_type': current_node.view_type})
448+
node_chunk.end(chunk)
418449
else:
419450
list(result)
420451
# 添加节点
421-
self.node_context.append(current_node)
452+
self.append_node(current_node)
422453
return current_result
423454
except Exception as e:
424455
# 添加节点
425-
self.node_context.append(current_node)
456+
self.append_node(current_node)
426457
traceback.print_exc()
427458
self.answer += str(e)
428459
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
429460
self.params['chat_record_id'],
430-
str(e), False, 0, 0)
461+
current_node.id,
462+
current_node.up_node_id_list,
463+
str(e), False, 0, 0, {'node_is_end': True})
431464
if not self.node_chunk_manage.contains(node_chunk):
432465
self.node_chunk_manage.add_node_chunk(node_chunk)
433466
node_chunk.add_chunk(chunk)
@@ -492,32 +525,36 @@ def get_runtime_details(self):
492525
continue
493526
details = node.get_details(index)
494527
details['node_id'] = node.id
528+
details['up_node_id_list'] = node.up_node_id_list
495529
details['runtime_node_id'] = node.runtime_node_id
496530
details_result[node.runtime_node_id] = details
497531
return details_result
498532

499533
def get_answer_text_list(self):
500-
answer_text_list = []
534+
result = []
535+
next_node_id_list = []
536+
if self.start_node is not None:
537+
next_node_id_list = [edge.targetNodeId for edge in self.flow.edges if
538+
edge.sourceNodeId == self.start_node.id]
501539
for index in range(len(self.node_context)):
502540
node = self.node_context[index]
541+
up_node = None
542+
if index > 0:
543+
up_node = self.node_context[index - 1]
503544
answer_text = node.get_answer_text()
504545
if answer_text is not None:
505-
if self.chat_record is not None and self.chat_record.details is not None:
506-
details = self.chat_record.details.get(node.runtime_node_id)
507-
if details is not None and self.start_node.runtime_node_id != node.runtime_node_id:
508-
continue
509-
answer_text_list.append(
510-
{'content': answer_text, 'type': 'form' if node.type == 'form-node' else 'md'})
511-
result = []
512-
for index in range(len(answer_text_list)):
513-
answer = answer_text_list[index]
514-
if index == 0:
515-
result.append(answer.get('content'))
516-
continue
517-
if answer.get('type') != answer_text_list[index - 1].get('type'):
518-
result.append(answer.get('content'))
519-
else:
520-
result[-1] += answer.get('content')
546+
if up_node is None or node.view_type == 'single_view' or (
547+
node.view_type == 'many_view' and up_node.view_type == 'single_view'):
548+
result.append(node.get_answer_text())
549+
elif self.chat_record is not None and next_node_id_list.__contains__(
550+
node.id) and up_node is not None and not next_node_id_list.__contains__(
551+
up_node.id):
552+
result.append(node.get_answer_text())
553+
else:
554+
content = result[len(result) - 1]
555+
answer_text = node.get_answer_text()
556+
result[len(result) - 1] += answer_text if len(
557+
content) == 0 else ('\n\n' + answer_text)
521558
return result
522559

523560
def get_next_node(self):
@@ -540,14 +577,28 @@ def get_next_node(self):
540577

541578
return None
542579

580+
@staticmethod
581+
def dependent_node(up_node_id, node):
582+
if node.id == up_node_id:
583+
if node.type == 'form-node':
584+
if node.context.get('form_data', None) is not None:
585+
return True
586+
return False
587+
return True
588+
543589
def dependent_node_been_executed(self, node_id):
544590
"""
545591
判断依赖节点是否都已执行
546592
@param node_id: 需要判断的节点id
547593
@return:
548594
"""
549595
up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
550-
return all([any([node.id == up_node_id for node in self.node_context]) for up_node_id in up_node_id_list])
596+
return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in
597+
up_node_id_list])
598+
599+
def get_up_node_id_list(self, node_id):
600+
up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
601+
return up_node_id_list
551602

552603
def get_next_node_list(self, current_node, current_node_result):
553604
"""
@@ -556,6 +607,7 @@ def get_next_node_list(self, current_node, current_node_result):
556607
@param current_node_result: 当前可执行节点结果
557608
@return: 可执行节点列表
558609
"""
610+
559611
if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable:
560612
return []
561613
node_list = []
@@ -564,11 +616,13 @@ def get_next_node_list(self, current_node, current_node_result):
564616
if (edge.sourceNodeId == current_node.id and
565617
f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
566618
if self.dependent_node_been_executed(edge.targetNodeId):
567-
node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
619+
node_list.append(
620+
self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
568621
else:
569622
for edge in self.flow.edges:
570623
if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId):
571-
node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
624+
node_list.append(
625+
self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
572626
return node_list
573627

574628
def get_reference_field(self, node_id: str, fields: List[str]):
@@ -629,11 +683,11 @@ def get_base_node(self):
629683
base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
630684
return base_node_list[0]
631685

632-
def get_node_cls_by_id(self, node_id, runtime_node_id=None):
686+
def get_node_cls_by_id(self, node_id, up_node_id_list=None):
633687
for node in self.flow.nodes:
634688
if node.id == node_id:
635689
node_instance = get_node(node.type)(node,
636-
self.params, self, runtime_node_id)
690+
self.params, self, up_node_id_list)
637691
return node_instance
638692
return None
639693

apps/application/serializers/chat_message_serializers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,13 @@ class ChatMessageSerializer(serializers.Serializer):
224224
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
225225
chat_record_id = serializers.UUIDField(required=False, allow_null=True,
226226
error_messages=ErrMessage.uuid("对话记录id"))
227+
228+
node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
229+
error_messages=ErrMessage.char("节点id"))
230+
227231
runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
228-
error_messages=ErrMessage.char("节点id"))
232+
error_messages=ErrMessage.char("运行时节点id"))
233+
229234
node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数"))
230235
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
231236
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
@@ -339,7 +344,8 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
339344
'client_id': client_id,
340345
'client_type': client_type,
341346
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
342-
base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'),
347+
base_to_response, form_data, image_list, document_list,
348+
self.data.get('runtime_node_id'),
343349
self.data.get('node_data'), chat_record)
344350
r = work_flow_manage.run()
345351
return r

0 commit comments

Comments
 (0)