Skip to content

Commit 8337c8a

Browse files
committed
refactor: Workflow execution logic
1 parent c000ee4 commit 8337c8a

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

apps/application/flow/workflow_manage.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from setting.models import Model
3030
from setting.models_provider import get_model_credential
3131

32-
executor = ThreadPoolExecutor(max_workers=50)
32+
executor = ThreadPoolExecutor(max_workers=200)
3333

3434

3535
class Edge:
@@ -271,7 +271,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
271271
self.current_result = None
272272
self.answer = ""
273273
self.answer_list = ['']
274-
self.status = 0
274+
self.status = 200
275275
self.base_to_response = base_to_response
276276
self.chat_record = chat_record
277277
self.await_future_map = {}
@@ -384,8 +384,23 @@ def await_result(self, result):
384384
'', True, message_tokens, answer_tokens, {})
385385

386386
def run_chain_async(self, current_node, node_result_future):
387-
future = executor.submit(self.run_chain, current_node, node_result_future)
388-
return future
387+
return executor.submit(self.run_chain_manage, current_node, node_result_future)
388+
389+
def run_chain_manage(self, current_node, node_result_future):
390+
if current_node is None:
391+
start_node = self.get_start_node()
392+
current_node = get_node(start_node.type)(start_node, self.params, self)
393+
result = self.run_chain(current_node, node_result_future)
394+
node_list = self.get_next_node_list(current_node, result)
395+
if len(node_list) == 1:
396+
self.run_chain_manage(node_list[0], None)
397+
elif len(node_list) > 1:
398+
399+
# 获取到可执行的子节点
400+
result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in
401+
node_list]
402+
self.set_await_map(result_list)
403+
[r.get('future').result() for r in result_list]
389404

390405
def set_await_map(self, node_run_list):
391406
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
@@ -395,9 +410,6 @@ def set_await_map(self, node_run_list):
395410
for i in range(index)]
396411

397412
def run_chain(self, current_node, node_result_future=None):
398-
if current_node is None:
399-
start_node = self.get_start_node()
400-
current_node = get_node(start_node.type)(start_node, self.params, self)
401413
if node_result_future is None:
402414
node_result_future = self.run_node_future(current_node)
403415
try:
@@ -409,18 +421,10 @@ def run_chain(self, current_node, node_result_future=None):
409421
result = self.hand_event_node_result(current_node,
410422
node_result_future) if is_stream else self.hand_node_result(
411423
current_node, node_result_future)
412-
with self.lock:
413-
if current_node.status == 500:
414-
return
415-
node_list = self.get_next_node_list(current_node, result)
416-
# 获取到可执行的子节点
417-
result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list]
418-
self.set_await_map(result_list)
419-
[r.get('future').result() for r in result_list]
420-
if self.status == 0:
421-
self.status = 200
424+
return result
422425
except Exception as e:
423426
traceback.print_exc()
427+
return []
424428

425429
def hand_node_result(self, current_node, node_result_future):
426430
try:

0 commit comments

Comments
 (0)