@@ -108,14 +108,18 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
108
108
109
109
messages = input .messages or []
110
110
forwarded_props = input .forwarded_props
111
-
111
+ node_name_input = forwarded_props .get ('node_name' , None ) if forwarded_props else None
112
+
112
113
self .active_run ["manually_emitted_state" ] = None
113
- self .active_run ["node_name" ] = forwarded_props .get ('node_name' , None ) if forwarded_props else None
114
+ self .active_run ["node_name" ] = node_name_input
115
+ if self .active_run ["node_name" ] == "__end__" :
116
+ self .active_run ["node_name" ] = None
114
117
115
118
config = ensure_config (self .config .copy () if self .config else {})
116
119
config ["configurable" ] = {** (config .get ('configurable' , {})), "thread_id" : thread_id }
117
120
118
121
agent_state = await self .graph .aget_state (config )
122
+ self .active_run ["mode" ] = "continue" if thread_id and self .active_run .get ("node_name" ) != "__end__" and self .active_run .get ("node_name" ) else "start"
119
123
prepared_stream_response = await self .prepare_stream (input = input , agent_state = agent_state , config = config )
120
124
121
125
yield self ._dispatch_event (
@@ -151,48 +155,48 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
151
155
return
152
156
153
157
should_exit = False
154
- latest_state_values = state
158
+ current_graph_state = state
155
159
async for event in stream :
156
160
if event ["event" ] == "error" :
157
161
yield self ._dispatch_event (
158
162
RunErrorEvent (type = EventType .RUN_ERROR , message = event ["data" ]["message" ], raw_event = event )
159
163
)
160
164
break
161
165
162
- if event ["event" ] == "values" :
163
- latest_state_values = event ["data" ]
164
- continue
165
-
166
- if event ["event" ] == "updates" :
167
- continue
168
-
169
166
current_node_name = event .get ("metadata" , {}).get ("langgraph_node" )
170
167
event_type = event .get ("event" )
171
168
self .active_run ["id" ] = event .get ("run_id" )
169
+ exiting_node = False
170
+
171
+ if event_type == "on_chain_end" and isinstance (
172
+ event .get ("data" , {}).get ("output" ), dict
173
+ ):
174
+ current_graph_state .update (event ["data" ]["output" ])
175
+ exiting_node = self .active_run ["node_name" ] == current_node_name
172
176
173
177
should_exit = should_exit or (
174
178
event_type == "on_custom_event" and
175
179
event ["name" ] == "exit"
176
180
)
177
181
178
182
if current_node_name and current_node_name != self .active_run .get ("node_name" ):
179
- if self .active_run . get ( "node_name" ) :
183
+ if self .active_run [ "node_name" ] and self . active_run [ "node_name" ] != node_name_input :
180
184
yield self ._dispatch_event (
181
185
StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
182
186
)
183
-
184
- if current_node_name :
185
- yield self ._dispatch_event (
186
- StepStartedEvent (type = EventType .STEP_STARTED , step_name = current_node_name )
187
- )
188
- self .active_run ["node_name" ] = current_node_name
187
+ self .active_run ["node_name" ] = None
189
188
190
- updated_state = self .active_run .get ("manually_emitted_state" ) or latest_state_values
191
- has_state_diff = updated_state != state
189
+ yield self ._dispatch_event (
190
+ StepStartedEvent (type = EventType .STEP_STARTED , step_name = current_node_name )
191
+ )
192
+ self .active_run ["node_name" ] = current_node_name
192
193
193
- if has_state_diff and not self .get_message_in_progress (self .active_run ["id" ]):
194
+ updated_state = self .active_run .get ("manually_emitted_state" ) or current_graph_state
195
+ has_state_diff = updated_state != state
196
+ if exiting_node or (has_state_diff and not self .get_message_in_progress (self .active_run ["id" ])):
194
197
state = updated_state
195
198
self .active_run ["prev_node_name" ] = self .active_run ["node_name" ]
199
+ current_graph_state .update (updated_state )
196
200
yield self ._dispatch_event (
197
201
StateSnapshotEvent (
198
202
type = EventType .STATE_SNAPSHOT ,
@@ -208,18 +212,17 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
208
212
async for single_event in self ._handle_single_event (event , state ):
209
213
yield single_event
210
214
211
- state_after_run = await self .graph .aget_state (config )
212
- tasks = state_after_run .tasks
213
- interrupts = tasks [0 ].interrupts if tasks and len (tasks ) > 0 else []
215
+ state = await self .graph .aget_state (config )
214
216
215
- if interrupts :
216
- self .active_run ["node_name" ] = self .active_run ["node_name" ]
217
- elif "writes" in state_after_run .metadata and state_after_run .metadata ["writes" ]:
218
- self .active_run ["node_name" ] = list (state_after_run .metadata ["writes" ].keys ())[0 ]
219
- elif hasattr (state_after_run , "next" ) and state_after_run .next and state_after_run .next [0 ]:
220
- self .active_run ["node_name" ] = state_after_run .next [0 ]
221
- else :
222
- self .active_run ["node_name" ] = "__end__"
217
+ tasks = state .tasks if len (state .tasks ) > 0 else None
218
+ interrupts = tasks [0 ].interrupts if tasks else []
219
+
220
+ writes = state .metadata .get ("writes" , {}) or {}
221
+ node_name = self .active_run ["node_name" ] if interrupts else next (iter (writes ), None )
222
+ next_nodes = state .next or ()
223
+ is_end_node = len (next_nodes ) == 0 and not interrupts
224
+
225
+ node_name = "__end__" if is_end_node else node_name
223
226
224
227
for interrupt in interrupts :
225
228
yield self ._dispatch_event (
@@ -231,25 +234,38 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
231
234
)
232
235
)
233
236
237
+ if self .active_run .get ("node_name" ) != node_name :
238
+ yield self ._dispatch_event (
239
+ StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
240
+ )
241
+ self .active_run ["node_name" ] = node_name
242
+ yield self ._dispatch_event (
243
+ StepStartedEvent (type = EventType .STEP_STARTED , step_name = self .active_run ["node_name" ])
244
+ )
245
+
246
+ # if tasks is None:
234
247
yield self ._dispatch_event (
235
- StateSnapshotEvent (type = EventType .STATE_SNAPSHOT , snapshot = self .get_state_snapshot ( state_after_run . values ) )
248
+ StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run [ "node_name" ] )
236
249
)
237
-
250
+ self .active_run ["node_name" ] = None
251
+
252
+ state_values = state .values if state .values else state
253
+ yield self ._dispatch_event (
254
+ StateSnapshotEvent (type = EventType .STATE_SNAPSHOT , snapshot = self .get_state_snapshot (state_values ))
255
+ )
256
+
238
257
yield self ._dispatch_event (
239
258
MessagesSnapshotEvent (
240
259
type = EventType .MESSAGES_SNAPSHOT ,
241
- messages = langchain_messages_to_agui (state_after_run . values .get ("messages" , [])),
260
+ messages = langchain_messages_to_agui (state_values .get ("messages" , [])),
242
261
)
243
262
)
244
263
245
- if self .active_run .get ("node_name" ):
246
- yield self ._dispatch_event (
247
- StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
248
- )
249
-
250
264
yield self ._dispatch_event (
251
265
RunFinishedEvent (type = EventType .RUN_FINISHED , thread_id = thread_id , run_id = self .active_run ["id" ])
252
266
)
267
+ self .active_run = None
268
+
253
269
254
270
async def prepare_stream (self , input : RunAgentInput , agent_state : State , config : RunnableConfig ):
255
271
state_input = input .state or {}
@@ -259,7 +275,6 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
259
275
thread_id = input .thread_id
260
276
261
277
state_input ["messages" ] = agent_state .values .get ("messages" , [])
262
- # TODO: validate if we need current graph state
263
278
self .active_run ["current_graph_state" ] = agent_state .values
264
279
langchain_messages = agui_messages_to_langchain (messages )
265
280
state = self .langgraph_default_merge_state (state_input , langchain_messages , tools )
@@ -295,9 +310,7 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
295
310
"events_to_dispatch" : events_to_dispatch ,
296
311
}
297
312
298
- mode = "continue" if thread_id and self .active_run .get ("node_name" ) != "__end__" and self .active_run .get ("node_name" ) else "start"
299
-
300
- if mode == "continue" :
313
+ if self .active_run ["mode" ] == "continue" :
301
314
await self .graph .aupdate_state (config , state , as_node = self .active_run .get ("node_name" ))
302
315
303
316
self .active_run ["schema_keys" ] = self .get_schema_keys (config )
@@ -306,7 +319,7 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
306
319
stream_input = Command (resume = resume_input )
307
320
else :
308
321
payload_input = get_stream_payload_input (
309
- mode = mode ,
322
+ mode = self . active_run [ " mode" ] ,
310
323
state = state ,
311
324
schema_keys = self .active_run ["schema_keys" ],
312
325
)
@@ -466,25 +479,22 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
466
479
)
467
480
468
481
if is_tool_call_end_event :
469
- resolved = self ._dispatch_event (
482
+ yield self ._dispatch_event (
470
483
ToolCallEndEvent (type = EventType .TOOL_CALL_END , tool_call_id = current_stream ["tool_call_id" ], raw_event = event )
471
484
)
472
- if resolved :
473
- self .messages_in_process [self .active_run ["id" ]] = None
474
- yield resolved
485
+ self .messages_in_process [self .active_run ["id" ]] = None
475
486
return
476
487
488
+
477
489
if is_message_end_event :
478
- resolved = self ._dispatch_event (
490
+ yield self ._dispatch_event (
479
491
TextMessageEndEvent (type = EventType .TEXT_MESSAGE_END , message_id = current_stream ["id" ], raw_event = event )
480
492
)
481
- if resolved :
482
- self .messages_in_process [self .active_run ["id" ]] = None
483
- yield resolved
493
+ self .messages_in_process [self .active_run ["id" ]] = None
484
494
return
485
495
486
496
if is_tool_call_start_event and should_emit_tool_calls :
487
- resolved = self ._dispatch_event (
497
+ yield self ._dispatch_event (
488
498
ToolCallStartEvent (
489
499
type = EventType .TOOL_CALL_START ,
490
500
tool_call_id = tool_call_data ["id" ],
@@ -493,12 +503,10 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
493
503
raw_event = event ,
494
504
)
495
505
)
496
- if resolved :
497
- self .set_message_in_progress (
498
- self .active_run ["id" ],
499
- MessageInProgress (id = event ["data" ]["chunk" ].id , tool_call_id = tool_call_data ["id" ], tool_call_name = tool_call_data ["name" ])
500
- )
501
- yield resolved
506
+ self .set_message_in_progress (
507
+ self .active_run ["id" ],
508
+ MessageInProgress (id = event ["data" ]["chunk" ].id , tool_call_id = tool_call_data ["id" ], tool_call_name = tool_call_data ["name" ])
509
+ )
502
510
return
503
511
504
512
if is_tool_call_args_event and should_emit_tool_calls :
@@ -672,3 +680,4 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
672
680
return history_list [idx - 1 ] # return one snapshot *before* the one that includes the message
673
681
674
682
raise ValueError ("Message ID not found in history" )
683
+
0 commit comments