Skip to content

Commit 54b5b80

Browse files
authored
fix(workflow): fix answer node stream processing in conditional branches (langgenius#12510)
1 parent 831459b commit 54b5b80

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

api/core/workflow/nodes/answer/base_stream_processor.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from abc import ABC, abstractmethod
33
from collections.abc import Generator
4+
from typing import Optional
45

56
from core.workflow.entities.variable_pool import VariablePool
67
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
@@ -48,25 +49,35 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExcept
4849
# we remove the node maybe shortcut the answer node, so comment this code for now
4950
# there is not effect on the answer node and the workflow, when we have a better solution
5051
# we can open this code. Issues: #11542 #9560 #10638 #10564
51-
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
52-
if "answer" in ids:
53-
continue
54-
else:
55-
reachable_node_ids.extend(ids)
52+
# ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
53+
# if "answer" in ids:
54+
# continue
55+
# else:
56+
# reachable_node_ids.extend(ids)
57+
58+
# The branch_identify parameter is added to ensure that
59+
# only nodes in the correct logical branch are included.
60+
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
61+
reachable_node_ids.extend(ids)
5662
else:
5763
unreachable_first_node_ids.append(edge.target_node_id)
5864

5965
for node_id in unreachable_first_node_ids:
6066
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
6167

62-
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
68+
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
6369
node_ids = []
6470
for edge in self.graph.edge_mapping.get(node_id, []):
6571
if edge.target_node_id == self.graph.root_node_id:
6672
continue
6773

74+
# Only follow edges that match the branch_identify or have no run_condition
75+
if edge.run_condition and edge.run_condition.branch_identify:
76+
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
77+
continue
78+
6879
node_ids.append(edge.target_node_id)
69-
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
80+
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
7081
return node_ids
7182

7283
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:

0 commit comments

Comments
 (0)