2929from setting .models import Model
3030from setting .models_provider import get_model_credential
3131
32- executor = ThreadPoolExecutor (max_workers = 50 )
32+ executor = ThreadPoolExecutor (max_workers = 200 )
3333
3434
3535class Edge :
@@ -271,7 +271,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
271271 self .current_result = None
272272 self .answer = ""
273273 self .answer_list = ['' ]
274- self .status = 0
274+ self .status = 200
275275 self .base_to_response = base_to_response
276276 self .chat_record = chat_record
277277 self .await_future_map = {}
@@ -384,8 +384,23 @@ def await_result(self, result):
384384 '' , True , message_tokens , answer_tokens , {})
385385
386386 def run_chain_async (self , current_node , node_result_future ):
387- future = executor .submit (self .run_chain , current_node , node_result_future )
388- return future
387+ return executor .submit (self .run_chain_manage , current_node , node_result_future )
388+
389+ def run_chain_manage (self , current_node , node_result_future ):
390+ if current_node is None :
391+ start_node = self .get_start_node ()
392+ current_node = get_node (start_node .type )(start_node , self .params , self )
393+ result = self .run_chain (current_node , node_result_future )
394+ node_list = self .get_next_node_list (current_node , result )
395+ if len (node_list ) == 1 :
396+ self .run_chain_manage (node_list [0 ], None )
397+ elif len (node_list ) > 1 :
398+
399+ # 获取到可执行的子节点
400+ result_list = [{'node' : node , 'future' : executor .submit (self .run_chain_manage , node , None )} for node in
401+ node_list ]
402+ self .set_await_map (result_list )
403+ [r .get ('future' ).result () for r in result_list ]
389404
390405 def set_await_map (self , node_run_list ):
391406 sorted_node_run_list = sorted (node_run_list , key = lambda n : n .get ('node' ).node .y )
@@ -395,9 +410,6 @@ def set_await_map(self, node_run_list):
395410 for i in range (index )]
396411
397412 def run_chain (self , current_node , node_result_future = None ):
398- if current_node is None :
399- start_node = self .get_start_node ()
400- current_node = get_node (start_node .type )(start_node , self .params , self )
401413 if node_result_future is None :
402414 node_result_future = self .run_node_future (current_node )
403415 try :
@@ -409,18 +421,10 @@ def run_chain(self, current_node, node_result_future=None):
409421 result = self .hand_event_node_result (current_node ,
410422 node_result_future ) if is_stream else self .hand_node_result (
411423 current_node , node_result_future )
412- with self .lock :
413- if current_node .status == 500 :
414- return
415- node_list = self .get_next_node_list (current_node , result )
416- # 获取到可执行的子节点
417- result_list = [{'node' : node , 'future' : self .run_chain_async (node , None )} for node in node_list ]
418- self .set_await_map (result_list )
419- [r .get ('future' ).result () for r in result_list ]
420- if self .status == 0 :
421- self .status = 200
424+ return result
422425 except Exception as e :
423426 traceback .print_exc ()
427+ return []
424428
425429 def hand_node_result (self , current_node , node_result_future ):
426430 try :
0 commit comments