66 @date:2024/1/9 17:40
77 @desc:
88"""
9+ import concurrent
910import json
1011import threading
1112import traceback
1213from concurrent .futures import ThreadPoolExecutor
1314from functools import reduce
1415from typing import List , Dict
1516
17+ from django .db import close_old_connections
1618from django .db .models import QuerySet
1719from langchain_core .prompts import PromptTemplate
1820from rest_framework import status
@@ -223,23 +225,6 @@ def pop(self):
223225 return None
224226
225227
226- class NodeChunk :
227- def __init__ (self ):
228- self .status = 0
229- self .chunk_list = []
230-
231- def add_chunk (self , chunk ):
232- self .chunk_list .append (chunk )
233-
234- def end (self , chunk = None ):
235- if chunk is not None :
236- self .add_chunk (chunk )
237- self .status = 200
238-
239- def is_end (self ):
240- return self .status == 200
241-
242-
243228class WorkflowManage :
244229 def __init__ (self , flow : Flow , params , work_flow_post_handler : WorkFlowPostHandler ,
245230 base_to_response : BaseToResponse = SystemToResponse (), form_data = None , image_list = None ,
@@ -273,8 +258,9 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
273258 self .status = 200
274259 self .base_to_response = base_to_response
275260 self .chat_record = chat_record
276- self .await_future_map = {}
277261 self .child_node = child_node
262+ self .future_list = []
263+ self .lock = threading .Lock ()
278264 if start_node_id is not None :
279265 self .load_node (chat_record , start_node_id , start_node_data )
280266 else :
@@ -319,6 +305,7 @@ def get_node_params(n):
319305 self .node_context .append (node )
320306
321307 def run (self ):
308+ close_old_connections ()
322309 if self .params .get ('stream' ):
323310 return self .run_stream (self .start_node , None )
324311 return self .run_block ()
@@ -328,8 +315,9 @@ def run_block(self):
328315 非流式响应
329316 @return: 结果
330317 """
331- result = self .run_chain_async (None , None )
332- result .result ()
318+ self .run_chain_async (None , None )
319+ while self .is_run ():
320+ pass
333321 details = self .get_runtime_details ()
334322 message_tokens = sum ([row .get ('message_tokens' ) for row in details .values () if
335323 'message_tokens' in row and row .get ('message_tokens' ) is not None ])
@@ -350,12 +338,22 @@ def run_stream(self, current_node, node_result_future):
350338 流式响应
351339 @return:
352340 """
353- result = self .run_chain_async (current_node , node_result_future )
354- return tools .to_stream_response_simple (self .await_result (result ))
341+ self .run_chain_async (current_node , node_result_future )
342+ return tools .to_stream_response_simple (self .await_result ())
355343
356- def await_result (self , result ):
344+ def is_run (self , timeout = 0.1 ):
345+ self .lock .acquire ()
357346 try :
358- while await_result (result ):
347+ r = concurrent .futures .wait (self .future_list , timeout )
348+ return len (r .not_done ) > 0
349+ except Exception as e :
350+ return True
351+ finally :
352+ self .lock .release ()
353+
354+ def await_result (self ):
355+ try :
356+ while self .is_run ():
359357 while True :
360358 chunk = self .node_chunk_manage .pop ()
361359 if chunk is not None :
@@ -383,42 +381,39 @@ def await_result(self, result):
383381 '' , True , message_tokens , answer_tokens , {})
384382
385383 def run_chain_async (self , current_node , node_result_future ):
386- return executor .submit (self .run_chain_manage , current_node , node_result_future )
384+ future = executor .submit (self .run_chain_manage , current_node , node_result_future )
385+ self .future_list .append (future )
387386
388387 def run_chain_manage (self , current_node , node_result_future ):
389388 if current_node is None :
390389 start_node = self .get_start_node ()
391390 current_node = get_node (start_node .type )(start_node , self .params , self )
391+ self .node_chunk_manage .add_node_chunk (current_node .node_chunk )
392+ # 添加节点
393+ self .append_node (current_node )
392394 result = self .run_chain (current_node , node_result_future )
393395 if result is None :
394396 return
395397 node_list = self .get_next_node_list (current_node , result )
396398 if len (node_list ) == 1 :
397399 self .run_chain_manage (node_list [0 ], None )
398400 elif len (node_list ) > 1 :
399-
401+ sorted_node_run_list = sorted ( node_list , key = lambda n : n . node . y )
400402 # 获取到可执行的子节点
401403 result_list = [{'node' : node , 'future' : executor .submit (self .run_chain_manage , node , None )} for node in
402- node_list ]
403- self .set_await_map (result_list )
404- [r .get ('future' ).result () for r in result_list ]
405-
406- def set_await_map (self , node_run_list ):
407- sorted_node_run_list = sorted (node_run_list , key = lambda n : n .get ('node' ).node .y )
408- for index in range (len (sorted_node_run_list )):
409- self .await_future_map [sorted_node_run_list [index ].get ('node' ).runtime_node_id ] = [
410- sorted_node_run_list [i ].get ('future' )
411- for i in range (index )]
404+ sorted_node_run_list ]
405+ try :
406+ self .lock .acquire ()
407+ for r in result_list :
408+ self .future_list .append (r .get ('future' ))
409+ finally :
410+ self .lock .release ()
412411
413412 def run_chain (self , current_node , node_result_future = None ):
414413 if node_result_future is None :
415414 node_result_future = self .run_node_future (current_node )
416415 try :
417416 is_stream = self .params .get ('stream' , True )
418- # 处理节点响应
419- await_future_list = self .await_future_map .get (current_node .runtime_node_id , None )
420- if await_future_list is not None :
421- [f .result () for f in await_future_list ]
422417 result = self .hand_event_node_result (current_node ,
423418 node_result_future ) if is_stream else self .hand_node_result (
424419 current_node , node_result_future )
@@ -434,16 +429,14 @@ def hand_node_result(self, current_node, node_result_future):
434429 if result is not None :
435430 # 阻塞获取结果
436431 list (result )
437- # 添加节点
438- self .node_context .append (current_node )
439432 return current_result
440433 except Exception as e :
441- # 添加节点
442- self .node_context .append (current_node )
443434 traceback .print_exc ()
444435 self .status = 500
445436 current_node .get_write_error_context (e )
446437 self .answer += str (e )
438+ finally :
439+ current_node .node_chunk .end ()
447440
448441 def append_node (self , current_node ):
449442 for index in range (len (self .node_context )):
@@ -454,15 +447,14 @@ def append_node(self, current_node):
454447 self .node_context .append (current_node )
455448
456449 def hand_event_node_result (self , current_node , node_result_future ):
457- node_chunk = NodeChunk ()
458450 real_node_id = current_node .runtime_node_id
459451 child_node = {}
452+ view_type = current_node .view_type
460453 try :
461454 current_result = node_result_future .result ()
462455 result = current_result .write_context (current_node , self )
463456 if result is not None :
464457 if self .is_result (current_node , current_result ):
465- self .node_chunk_manage .add_node_chunk (node_chunk )
466458 for r in result :
467459 content = r
468460 child_node = {}
@@ -487,26 +479,24 @@ def hand_event_node_result(self, current_node, node_result_future):
487479 'child_node' : child_node ,
488480 'node_is_end' : node_is_end ,
489481 'real_node_id' : real_node_id })
490- node_chunk .add_chunk (chunk )
491- chunk = self .base_to_response .to_stream_chunk_response (self .params ['chat_id' ],
492- self .params ['chat_record_id' ],
493- current_node .id ,
494- current_node .up_node_id_list ,
495- '' , False , 0 , 0 , {'node_is_end' : True ,
496- 'runtime_node_id' : current_node .runtime_node_id ,
497- 'node_type' : current_node .type ,
498- 'view_type' : view_type ,
499- 'child_node' : child_node ,
500- 'real_node_id' : real_node_id })
501- node_chunk .end (chunk )
482+ current_node .node_chunk .add_chunk (chunk )
483+ chunk = (self .base_to_response
484+ .to_stream_chunk_response (self .params ['chat_id' ],
485+ self .params ['chat_record_id' ],
486+ current_node .id ,
487+ current_node .up_node_id_list ,
488+ '' , False , 0 , 0 , {'node_is_end' : True ,
489+ 'runtime_node_id' : current_node .runtime_node_id ,
490+ 'node_type' : current_node .type ,
491+ 'view_type' : view_type ,
492+ 'child_node' : child_node ,
493+ 'real_node_id' : real_node_id }))
494+ current_node .node_chunk .add_chunk (chunk )
502495 else :
503496 list (result )
504- # 添加节点
505- self .append_node (current_node )
506497 return current_result
507498 except Exception as e :
508499 # 添加节点
509- self .append_node (current_node )
510500 traceback .print_exc ()
511501 chunk = self .base_to_response .to_stream_chunk_response (self .params ['chat_id' ],
512502 self .params ['chat_record_id' ],
@@ -519,12 +509,12 @@ def hand_event_node_result(self, current_node, node_result_future):
519509 'view_type' : current_node .view_type ,
520510 'child_node' : {},
521511 'real_node_id' : real_node_id })
522- if not self .node_chunk_manage .contains (node_chunk ):
523- self .node_chunk_manage .add_node_chunk (node_chunk )
524- node_chunk .end (chunk )
512+ current_node .node_chunk .add_chunk (chunk )
525513 current_node .get_write_error_context (e )
526514 self .status = 500
527515 return None
516+ finally :
517+ current_node .node_chunk .end ()
528518
529519 def run_node_async (self , node ):
530520 future = executor .submit (self .run_node , node )
@@ -636,6 +626,8 @@ def get_next_node(self):
636626
637627 @staticmethod
638628 def dependent_node (up_node_id , node ):
629+ if not node .node_chunk .is_end ():
630+ return False
639631 if node .id == up_node_id :
640632 if node .type == 'form-node' :
641633 if node .context .get ('form_data' , None ) is not None :
0 commit comments