@@ -52,7 +52,8 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa
5252 self .__setattr__ (keyword , kwargs .get (keyword ))
5353
5454
55- end_nodes = ['ai-chat-node' , 'reply-node' , 'function-node' , 'function-lib-node' , 'application-node' , 'image-understand-node' ]
55+ end_nodes = ['ai-chat-node' , 'reply-node' , 'function-node' , 'function-lib-node' , 'application-node' ,
56+ 'image-understand-node' ]
5657
5758
5859class Flow :
@@ -229,7 +230,9 @@ def __init__(self):
229230 def add_chunk (self , chunk ):
230231 self .chunk_list .append (chunk )
231232
232- def end (self ):
233+ def end (self , chunk = None ):
234+ if chunk is not None :
235+ self .add_chunk (chunk )
233236 self .status = 200
234237
235238 def is_end (self ):
@@ -266,6 +269,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
266269 self .status = 0
267270 self .base_to_response = base_to_response
268271 self .chat_record = chat_record
272+ self .await_future_map = {}
269273 if start_node_id is not None :
270274 self .load_node (chat_record , start_node_id , start_node_data )
271275 else :
@@ -286,14 +290,16 @@ def load_node(self, chat_record, start_node_id, start_node_data):
286290 for node_details in sorted (chat_record .details .values (), key = lambda d : d .get ('index' )):
287291 node_id = node_details .get ('node_id' )
288292 if node_details .get ('runtime_node_id' ) == start_node_id :
289- self .start_node = self .get_node_cls_by_id (node_id , node_details .get ('runtime_node_id ' ))
293+ self .start_node = self .get_node_cls_by_id (node_id , node_details .get ('up_node_id_list ' ))
290294 self .start_node .valid_args (self .start_node .node_params , self .start_node .workflow_params )
291295 self .start_node .save_context (node_details , self )
292296 node_result = NodeResult ({** start_node_data , 'form_data' : start_node_data , 'is_submit' : True }, {})
293297 self .start_node_result_future = NodeResultFuture (node_result , None )
294- return
298+ self .node_context .append (self .start_node )
299+ continue
300+
295301 node_id = node_details .get ('node_id' )
296- node = self .get_node_cls_by_id (node_id , node_details .get ('runtime_node_id ' ))
302+ node = self .get_node_cls_by_id (node_id , node_details .get ('up_node_id_list ' ))
297303 node .valid_args (node .node_params , node .workflow_params )
298304 node .save_context (node_details , self )
299305 self .node_context .append (node )
@@ -345,17 +351,22 @@ def await_result(self, result):
345351 if chunk is None :
346352 break
347353 yield chunk
348- yield self .get_chunk_content ('' , True )
349354 finally :
350355 self .work_flow_post_handler .handler (self .params ['chat_id' ], self .params ['chat_record_id' ],
351356 self .answer ,
352357 self )
353- yield self .get_chunk_content ('' , True )
354358
355359 def run_chain_async (self , current_node , node_result_future ):
356360 future = executor .submit (self .run_chain , current_node , node_result_future )
357361 return future
358362
363+ def set_await_map (self , node_run_list ):
364+ sorted_node_run_list = sorted (node_run_list , key = lambda n : n .get ('node' ).node .y )
365+ for index in range (len (sorted_node_run_list )):
366+ self .await_future_map [sorted_node_run_list [index ].get ('node' ).runtime_node_id ] = [
367+ sorted_node_run_list [i ].get ('future' )
368+ for i in range (index )]
369+
359370 def run_chain (self , current_node , node_result_future = None ):
360371 if current_node is None :
361372 start_node = self .get_start_node ()
@@ -365,6 +376,9 @@ def run_chain(self, current_node, node_result_future=None):
365376 try :
366377 is_stream = self .params .get ('stream' , True )
367378 # 处理节点响应
379+ await_future_list = self .await_future_map .get (current_node .runtime_node_id , None )
380+ if await_future_list is not None :
381+ [f .result () for f in await_future_list ]
368382 result = self .hand_event_node_result (current_node ,
369383 node_result_future ) if is_stream else self .hand_node_result (
370384 current_node , node_result_future )
@@ -373,11 +387,9 @@ def run_chain(self, current_node, node_result_future=None):
373387 return
374388 node_list = self .get_next_node_list (current_node , result )
375389 # 获取到可执行的子节点
376- result_list = []
377- for node in node_list :
378- result = self .run_chain_async (node , None )
379- result_list .append (result )
380- [r .result () for r in result_list ]
390+ result_list = [{'node' : node , 'future' : self .run_chain_async (node , None )} for node in node_list ]
391+ self .set_await_map (result_list )
392+ [r .get ('future' ).result () for r in result_list ]
381393 if self .status == 0 :
382394 self .status = 200
383395 except Exception as e :
@@ -401,6 +413,14 @@ def hand_node_result(self, current_node, node_result_future):
401413 current_node .get_write_error_context (e )
402414 self .answer += str (e )
403415
416+ def append_node (self , current_node ):
417+ for index in range (len (self .node_context )):
418+ n = self .node_context [index ]
419+ if current_node .id == n .node .id and current_node .runtime_node_id == n .runtime_node_id :
420+ self .node_context [index ] = current_node
421+ return
422+ self .node_context .append (current_node )
423+
404424 def hand_event_node_result (self , current_node , node_result_future ):
405425 node_chunk = NodeChunk ()
406426 try :
@@ -412,22 +432,35 @@ def hand_event_node_result(self, current_node, node_result_future):
412432 for r in result :
413433 chunk = self .base_to_response .to_stream_chunk_response (self .params ['chat_id' ],
414434 self .params ['chat_record_id' ],
415- r , False , 0 , 0 )
435+ current_node .id ,
436+ current_node .up_node_id_list ,
437+ r , False , 0 , 0 ,
438+ {'node_type' : current_node .type ,
439+ 'view_type' : current_node .view_type })
416440 node_chunk .add_chunk (chunk )
417- node_chunk .end ()
441+ chunk = self .base_to_response .to_stream_chunk_response (self .params ['chat_id' ],
442+ self .params ['chat_record_id' ],
443+ current_node .id ,
444+ current_node .up_node_id_list ,
445+ '' , False , 0 , 0 , {'node_is_end' : True ,
446+ 'node_type' : current_node .type ,
447+ 'view_type' : current_node .view_type })
448+ node_chunk .end (chunk )
418449 else :
419450 list (result )
420451 # 添加节点
421- self .node_context . append (current_node )
452+ self .append_node (current_node )
422453 return current_result
423454 except Exception as e :
424455 # 添加节点
425- self .node_context . append (current_node )
456+ self .append_node (current_node )
426457 traceback .print_exc ()
427458 self .answer += str (e )
428459 chunk = self .base_to_response .to_stream_chunk_response (self .params ['chat_id' ],
429460 self .params ['chat_record_id' ],
430- str (e ), False , 0 , 0 )
461+ current_node .id ,
462+ current_node .up_node_id_list ,
463+ str (e ), False , 0 , 0 , {'node_is_end' : True })
431464 if not self .node_chunk_manage .contains (node_chunk ):
432465 self .node_chunk_manage .add_node_chunk (node_chunk )
433466 node_chunk .add_chunk (chunk )
@@ -492,32 +525,36 @@ def get_runtime_details(self):
492525 continue
493526 details = node .get_details (index )
494527 details ['node_id' ] = node .id
528+ details ['up_node_id_list' ] = node .up_node_id_list
495529 details ['runtime_node_id' ] = node .runtime_node_id
496530 details_result [node .runtime_node_id ] = details
497531 return details_result
498532
499533 def get_answer_text_list (self ):
500- answer_text_list = []
534+ result = []
535+ next_node_id_list = []
536+ if self .start_node is not None :
537+ next_node_id_list = [edge .targetNodeId for edge in self .flow .edges if
538+ edge .sourceNodeId == self .start_node .id ]
501539 for index in range (len (self .node_context )):
502540 node = self .node_context [index ]
541+ up_node = None
542+ if index > 0 :
543+ up_node = self .node_context [index - 1 ]
503544 answer_text = node .get_answer_text ()
504545 if answer_text is not None :
505- if self .chat_record is not None and self .chat_record .details is not None :
506- details = self .chat_record .details .get (node .runtime_node_id )
507- if details is not None and self .start_node .runtime_node_id != node .runtime_node_id :
508- continue
509- answer_text_list .append (
510- {'content' : answer_text , 'type' : 'form' if node .type == 'form-node' else 'md' })
511- result = []
512- for index in range (len (answer_text_list )):
513- answer = answer_text_list [index ]
514- if index == 0 :
515- result .append (answer .get ('content' ))
516- continue
517- if answer .get ('type' ) != answer_text_list [index - 1 ].get ('type' ):
518- result .append (answer .get ('content' ))
519- else :
520- result [- 1 ] += answer .get ('content' )
546+ if up_node is None or node .view_type == 'single_view' or (
547+ node .view_type == 'many_view' and up_node .view_type == 'single_view' ):
548+ result .append (node .get_answer_text ())
549+ elif self .chat_record is not None and next_node_id_list .__contains__ (
550+ node .id ) and up_node is not None and not next_node_id_list .__contains__ (
551+ up_node .id ):
552+ result .append (node .get_answer_text ())
553+ else :
554+ content = result [len (result ) - 1 ]
555+ answer_text = node .get_answer_text ()
556+ result [len (result ) - 1 ] += answer_text if len (
557+ content ) == 0 else ('\n \n ' + answer_text )
521558 return result
522559
523560 def get_next_node (self ):
@@ -540,14 +577,28 @@ def get_next_node(self):
540577
541578 return None
542579
580+ @staticmethod
581+ def dependent_node (up_node_id , node ):
582+ if node .id == up_node_id :
583+ if node .type == 'form-node' :
584+ if node .context .get ('form_data' , None ) is not None :
585+ return True
586+ return False
587+ return True
588+
543589 def dependent_node_been_executed (self , node_id ):
544590 """
545591 判断依赖节点是否都已执行
546592 @param node_id: 需要判断的节点id
547593 @return:
548594 """
549595 up_node_id_list = [edge .sourceNodeId for edge in self .flow .edges if edge .targetNodeId == node_id ]
550- return all ([any ([node .id == up_node_id for node in self .node_context ]) for up_node_id in up_node_id_list ])
596+ return all ([any ([self .dependent_node (up_node_id , node ) for node in self .node_context ]) for up_node_id in
597+ up_node_id_list ])
598+
599+ def get_up_node_id_list (self , node_id ):
600+ up_node_id_list = [edge .sourceNodeId for edge in self .flow .edges if edge .targetNodeId == node_id ]
601+ return up_node_id_list
551602
552603 def get_next_node_list (self , current_node , current_node_result ):
553604 """
@@ -556,6 +607,7 @@ def get_next_node_list(self, current_node, current_node_result):
556607 @param current_node_result: 当前可执行节点结果
557608 @return: 可执行节点列表
558609 """
610+
559611 if current_node .type == 'form-node' and 'form_data' not in current_node_result .node_variable :
560612 return []
561613 node_list = []
@@ -564,11 +616,13 @@ def get_next_node_list(self, current_node, current_node_result):
564616 if (edge .sourceNodeId == current_node .id and
565617 f"{ edge .sourceNodeId } _{ current_node_result .node_variable .get ('branch_id' )} _right" == edge .sourceAnchorId ):
566618 if self .dependent_node_been_executed (edge .targetNodeId ):
567- node_list .append (self .get_node_cls_by_id (edge .targetNodeId ))
619+ node_list .append (
620+ self .get_node_cls_by_id (edge .targetNodeId , self .get_up_node_id_list (edge .targetNodeId )))
568621 else :
569622 for edge in self .flow .edges :
570623 if edge .sourceNodeId == current_node .id and self .dependent_node_been_executed (edge .targetNodeId ):
571- node_list .append (self .get_node_cls_by_id (edge .targetNodeId ))
624+ node_list .append (
625+ self .get_node_cls_by_id (edge .targetNodeId , self .get_up_node_id_list (edge .targetNodeId )))
572626 return node_list
573627
574628 def get_reference_field (self , node_id : str , fields : List [str ]):
@@ -629,11 +683,11 @@ def get_base_node(self):
629683 base_node_list = [node for node in self .flow .nodes if node .type == 'base-node' ]
630684 return base_node_list [0 ]
631685
632- def get_node_cls_by_id (self , node_id , runtime_node_id = None ):
686+ def get_node_cls_by_id (self , node_id , up_node_id_list = None ):
633687 for node in self .flow .nodes :
634688 if node .id == node_id :
635689 node_instance = get_node (node .type )(node ,
636- self .params , self , runtime_node_id )
690+ self .params , self , up_node_id_list )
637691 return node_instance
638692 return None
639693
0 commit comments