|
1 | 1 | import logging |
2 | 2 | from abc import ABC, abstractmethod |
3 | 3 | from collections.abc import Generator |
| 4 | +from typing import Optional |
4 | 5 |
|
5 | 6 | from core.workflow.entities.variable_pool import VariablePool |
6 | 7 | from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent |
@@ -48,25 +49,35 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExcept |
48 | 49 | # we remove the node maybe shortcut the answer node, so comment this code for now |
49 | 50 | # there is not effect on the answer node and the workflow, when we have a better solution |
50 | 51 | # we can open this code. Issues: #11542 #9560 #10638 #10564 |
51 | | - ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) |
52 | | - if "answer" in ids: |
53 | | - continue |
54 | | - else: |
55 | | - reachable_node_ids.extend(ids) |
| 52 | + # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) |
| 53 | + # if "answer" in ids: |
| 54 | + # continue |
| 55 | + # else: |
| 56 | + # reachable_node_ids.extend(ids) |
| 57 | + |
| 58 | + # The branch_identify parameter is added to ensure that |
| 59 | + # only nodes in the correct logical branch are included. |
| 60 | + ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) |
| 61 | + reachable_node_ids.extend(ids) |
56 | 62 | else: |
57 | 63 | unreachable_first_node_ids.append(edge.target_node_id) |
58 | 64 |
|
59 | 65 | for node_id in unreachable_first_node_ids: |
60 | 66 | self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) |
61 | 67 |
|
62 | | - def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: |
| 68 | + def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: |
63 | 69 | node_ids = [] |
64 | 70 | for edge in self.graph.edge_mapping.get(node_id, []): |
65 | 71 | if edge.target_node_id == self.graph.root_node_id: |
66 | 72 | continue |
67 | 73 |
|
| 74 | + # Only follow edges that match the branch_identify or have no run_condition |
| 75 | + if edge.run_condition and edge.run_condition.branch_identify: |
| 76 | + if not branch_identify or edge.run_condition.branch_identify != branch_identify: |
| 77 | + continue |
| 78 | + |
68 | 79 | node_ids.append(edge.target_node_id) |
69 | | - node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) |
| 80 | + node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)) |
70 | 81 | return node_ids |
71 | 82 |
|
72 | 83 | def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: |
|
0 commit comments