Skip to content

Commit 2e2c87c

Browse files
authored
fix(graph_engine): error strategy fallback. (#26078)
Signed-off-by: -LAN- <laipz8200@outlook.com>
1 parent f4522fd commit 2e2c87c

File tree

8 files changed

+255
-84
lines changed

8 files changed

+255
-84
lines changed

api/core/workflow/graph_engine/domain/graph_execution.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
4141
completed: bool = Field(default=False)
4242
aborted: bool = Field(default=False)
4343
error: GraphExecutionErrorState | None = Field(default=None)
44-
node_executions: list[NodeExecutionState] = Field(default_factory=list)
44+
exceptions_count: int = Field(default=0)
45+
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
4546

4647

4748
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@@ -103,7 +104,8 @@ class GraphExecution:
103104
completed: bool = False
104105
aborted: bool = False
105106
error: Exception | None = None
106-
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
107+
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
108+
exceptions_count: int = 0
107109

108110
def start(self) -> None:
109111
"""Mark the graph execution as started."""
@@ -172,6 +174,7 @@ def dumps(self) -> str:
172174
completed=self.completed,
173175
aborted=self.aborted,
174176
error=_serialize_error(self.error),
177+
exceptions_count=self.exceptions_count,
175178
node_executions=node_states,
176179
)
177180

@@ -195,6 +198,7 @@ def loads(self, data: str) -> None:
195198
self.completed = state.completed
196199
self.aborted = state.aborted
197200
self.error = _deserialize_error(state.error)
201+
self.exceptions_count = state.exceptions_count
198202
self.node_executions = {
199203
item.node_id: NodeExecution(
200204
node_id=item.node_id,
@@ -205,3 +209,7 @@ def loads(self, data: str) -> None:
205209
)
206210
for item in state.node_executions
207211
}
212+
213+
def record_node_failure(self) -> None:
214+
"""Increment the count of node failures encountered during execution."""
215+
self.exceptions_count += 1

api/core/workflow/graph_engine/event_management/event_handlers.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
"""
44

55
import logging
6+
from collections.abc import Mapping
67
from functools import singledispatchmethod
78
from typing import TYPE_CHECKING, final
89

910
from core.workflow.entities import GraphRuntimeState
10-
from core.workflow.enums import NodeExecutionType
11+
from core.workflow.enums import ErrorStrategy, NodeExecutionType
1112
from core.workflow.graph import Graph
1213
from core.workflow.graph_events import (
1314
GraphNodeEventBase,
@@ -122,13 +123,15 @@ def _(self, event: NodeRunStartedEvent) -> None:
122123
"""
123124
# Track execution in domain model
124125
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
126+
is_initial_attempt = node_execution.retry_count == 0
125127
node_execution.mark_started(event.id)
126128

127129
# Track in response coordinator for stream ordering
128130
self._response_coordinator.track_node_execution(event.node_id, event.id)
129131

130-
# Collect the event
131-
self._event_collector.collect(event)
132+
# Collect the event only for the first attempt; retries remain silent
133+
if is_initial_attempt:
134+
self._event_collector.collect(event)
132135

133136
@_dispatch.register
134137
def _(self, event: NodeRunStreamChunkEvent) -> None:
@@ -161,7 +164,7 @@ def _(self, event: NodeRunSucceededEvent) -> None:
161164
node_execution.mark_taken()
162165

163166
# Store outputs in variable pool
164-
self._store_node_outputs(event)
167+
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
165168

166169
# Forward to response coordinator and emit streaming events
167170
streaming_events = self._response_coordinator.intercept_event(event)
@@ -191,7 +194,7 @@ def _(self, event: NodeRunSucceededEvent) -> None:
191194

192195
# Handle response node outputs
193196
if node.execution_type == NodeExecutionType.RESPONSE:
194-
self._update_response_outputs(event)
197+
self._update_response_outputs(event.node_run_result.outputs)
195198

196199
# Collect the event
197200
self._event_collector.collect(event)
@@ -207,6 +210,7 @@ def _(self, event: NodeRunFailedEvent) -> None:
207210
# Update domain model
208211
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
209212
node_execution.mark_failed(event.error)
213+
self._graph_execution.record_node_failure()
210214

211215
result = self._error_handler.handle_node_failure(event)
212216

@@ -227,10 +231,40 @@ def _(self, event: NodeRunExceptionEvent) -> None:
227231
Args:
228232
event: The node exception event
229233
"""
230-
# Node continues via fail-branch, so it's technically "succeeded"
234+
# Node continues via fail-branch/default-value, treat as completion
231235
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
232236
node_execution.mark_taken()
233237

238+
# Persist outputs produced by the exception strategy (e.g. default values)
239+
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
240+
241+
node = self._graph.nodes[event.node_id]
242+
243+
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
244+
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
245+
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
246+
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
247+
event.node_id, event.node_run_result.edge_source_handle
248+
)
249+
else:
250+
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
251+
252+
for edge_event in edge_streaming_events:
253+
self._event_collector.collect(edge_event)
254+
255+
for node_id in ready_nodes:
256+
self._state_manager.enqueue_node(node_id)
257+
self._state_manager.start_execution(node_id)
258+
259+
# Update response outputs if applicable
260+
if node.execution_type == NodeExecutionType.RESPONSE:
261+
self._update_response_outputs(event.node_run_result.outputs)
262+
263+
self._state_manager.finish_execution(event.node_id)
264+
265+
# Collect the exception event for observers
266+
self._event_collector.collect(event)
267+
234268
@_dispatch.register
235269
def _(self, event: NodeRunRetryEvent) -> None:
236270
"""
@@ -242,21 +276,31 @@ def _(self, event: NodeRunRetryEvent) -> None:
242276
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
243277
node_execution.increment_retry()
244278

245-
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
279+
# Finish the previous attempt before re-queuing the node
280+
self._state_manager.finish_execution(event.node_id)
281+
282+
# Emit retry event for observers
283+
self._event_collector.collect(event)
284+
285+
# Re-queue node for execution
286+
self._state_manager.enqueue_node(event.node_id)
287+
self._state_manager.start_execution(event.node_id)
288+
289+
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
246290
"""
247291
Store node outputs in the variable pool.
248292
249293
Args:
250294
node_id: The ID of the node whose outputs are being stored
outputs: Mapping of output variable names to their values
251295
"""
252-
for variable_name, variable_value in event.node_run_result.outputs.items():
253-
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
296+
for variable_name, variable_value in outputs.items():
297+
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
254298

255-
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
299+
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
256300
"""Update response outputs for response nodes."""
257301
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
258302
# in runtime state, rather than allowing nodes to directly access runtime state.
259-
for key, value in event.node_run_result.outputs.items():
303+
for key, value in outputs.items():
260304
if key == "answer":
261305
existing = self._graph_runtime_state.get_output("answer", "")
262306
if existing:

api/core/workflow/graph_engine/graph_engine.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
GraphNodeEventBase,
2424
GraphRunAbortedEvent,
2525
GraphRunFailedEvent,
26+
GraphRunPartialSucceededEvent,
2627
GraphRunStartedEvent,
2728
GraphRunSucceededEvent,
2829
)
@@ -260,12 +261,23 @@ def run(self) -> Generator[GraphEngineEvent, None, None]:
260261
if self._graph_execution.error:
261262
raise self._graph_execution.error
262263
else:
263-
yield GraphRunSucceededEvent(
264-
outputs=self._graph_runtime_state.outputs,
265-
)
264+
outputs = self._graph_runtime_state.outputs
265+
exceptions_count = self._graph_execution.exceptions_count
266+
if exceptions_count > 0:
267+
yield GraphRunPartialSucceededEvent(
268+
exceptions_count=exceptions_count,
269+
outputs=outputs,
270+
)
271+
else:
272+
yield GraphRunSucceededEvent(
273+
outputs=outputs,
274+
)
266275

267276
except Exception as e:
268-
yield GraphRunFailedEvent(error=str(e))
277+
yield GraphRunFailedEvent(
278+
error=str(e),
279+
exceptions_count=self._graph_execution.exceptions_count,
280+
)
269281
raise
270282

271283
finally:

api/core/workflow/graph_engine/layers/debug_logging.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
GraphEngineEvent,
1616
GraphRunAbortedEvent,
1717
GraphRunFailedEvent,
18+
GraphRunPartialSucceededEvent,
1819
GraphRunStartedEvent,
1920
GraphRunSucceededEvent,
2021
NodeRunExceptionEvent,
@@ -127,6 +128,13 @@ def on_event(self, event: GraphEngineEvent) -> None:
127128
if self.include_outputs and event.outputs:
128129
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
129130

131+
elif isinstance(event, GraphRunPartialSucceededEvent):
132+
self.logger.warning("⚠️ Graph run partially succeeded")
133+
if event.exceptions_count > 0:
134+
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
135+
if self.include_outputs and event.outputs:
136+
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
137+
130138
elif isinstance(event, GraphRunFailedEvent):
131139
self.logger.error("❌ Graph run failed: %s", event.error)
132140
if event.exceptions_count > 0:

api/core/workflow/nodes/iteration/iteration_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from core.workflow.graph_events import (
2020
GraphNodeEventBase,
2121
GraphRunFailedEvent,
22+
GraphRunPartialSucceededEvent,
2223
GraphRunSucceededEvent,
2324
)
2425
from core.workflow.node_events import (
@@ -456,7 +457,7 @@ def _run_single_iter(
456457
if isinstance(event, GraphNodeEventBase):
457458
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
458459
yield event
459-
elif isinstance(event, GraphRunSucceededEvent):
460+
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
460461
result = variable_pool.get(self._node_data.output_selector)
461462
if result is None:
462463
outputs.append(None)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Tests for graph engine event handlers."""
2+
3+
from __future__ import annotations
4+
5+
from datetime import datetime
6+
7+
from core.workflow.entities import GraphRuntimeState, VariablePool
8+
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
9+
from core.workflow.graph import Graph
10+
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
11+
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
12+
from core.workflow.graph_engine.event_management.event_manager import EventManager
13+
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
14+
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
15+
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
16+
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
17+
from core.workflow.node_events import NodeRunResult
18+
from core.workflow.nodes.base.entities import RetryConfig
19+
20+
21+
class _StubEdgeProcessor:
22+
"""Minimal edge processor stub for tests."""
23+
24+
25+
class _StubErrorHandler:
26+
"""Minimal error handler stub for tests."""
27+
28+
29+
class _StubNode:
30+
"""Simple node stub exposing the attributes needed by the state manager."""
31+
32+
def __init__(self, node_id: str) -> None:
33+
self.id = node_id
34+
self.state = NodeState.UNKNOWN
35+
self.title = "Stub Node"
36+
self.execution_type = NodeExecutionType.EXECUTABLE
37+
self.error_strategy = None
38+
self.retry_config = RetryConfig()
39+
self.retry = False
40+
41+
42+
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
43+
"""Construct an EventHandler with in-memory dependencies for testing."""
44+
45+
node = _StubNode(node_id)
46+
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
47+
48+
variable_pool = VariablePool()
49+
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
50+
graph_execution = GraphExecution(workflow_id="test-workflow")
51+
52+
event_manager = EventManager()
53+
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
54+
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
55+
56+
handler = EventHandler(
57+
graph=graph,
58+
graph_runtime_state=runtime_state,
59+
graph_execution=graph_execution,
60+
response_coordinator=response_coordinator,
61+
event_collector=event_manager,
62+
edge_processor=_StubEdgeProcessor(),
63+
state_manager=state_manager,
64+
error_handler=_StubErrorHandler(),
65+
)
66+
67+
return handler, event_manager, graph_execution
68+
69+
70+
def test_retry_does_not_emit_additional_start_event() -> None:
71+
"""Ensure retry attempts do not produce duplicate start events."""
72+
73+
node_id = "test-node"
74+
handler, event_manager, graph_execution = _build_event_handler(node_id)
75+
76+
execution_id = "exec-1"
77+
node_type = NodeType.CODE
78+
start_time = datetime.utcnow()
79+
80+
start_event = NodeRunStartedEvent(
81+
id=execution_id,
82+
node_id=node_id,
83+
node_type=node_type,
84+
node_title="Stub Node",
85+
start_at=start_time,
86+
)
87+
handler.dispatch(start_event)
88+
89+
retry_event = NodeRunRetryEvent(
90+
id=execution_id,
91+
node_id=node_id,
92+
node_type=node_type,
93+
node_title="Stub Node",
94+
start_at=start_time,
95+
error="boom",
96+
retry_index=1,
97+
node_run_result=NodeRunResult(
98+
status=WorkflowNodeExecutionStatus.FAILED,
99+
error="boom",
100+
error_type="TestError",
101+
),
102+
)
103+
handler.dispatch(retry_event)
104+
105+
# Simulate the node starting execution again after retry
106+
second_start_event = NodeRunStartedEvent(
107+
id=execution_id,
108+
node_id=node_id,
109+
node_type=node_type,
110+
node_title="Stub Node",
111+
start_at=start_time,
112+
)
113+
handler.dispatch(second_start_event)
114+
115+
collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]
116+
117+
assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
118+
119+
node_execution = graph_execution.get_or_create_node_execution(node_id)
120+
assert node_execution.retry_count == 1

0 commit comments

Comments
 (0)