import uuid
import json
from typing import Optional, List, Any, Union, AsyncGenerator, Generator
+from dataclasses import is_dataclass, asdict
+from datetime import date, datetime


from langgraph.graph.state import CompiledStateGraph
from langchain.schema import BaseMessage, SystemMessage
@@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
        self.messages_in_process: MessagesInProgressRecord = {}
        self.active_run: Optional[RunMetadata] = None
        self.constant_schema_keys = ['messages', 'tools']
+        self.active_step = None

    def _dispatch_event(self, event: ProcessedEvents) -> str:
        return event  # Fallback if no encoder
@@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st

        # In case of resume (interrupt), re-start resumed step
        if resume_input and self.active_run.get("node_name"):
-            yield self._dispatch_event(
-                StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run.get("node_name"))
-            )
+            for ev in self.start_step(self.active_run.get("node_name")):
+                yield ev

        state = prepared_stream_response["state"]
        stream = prepared_stream_response["stream"]
@@ -151,6 +153,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st

        should_exit = False
        current_graph_state = state
+
        async for event in stream:
            if event["event"] == "error":
                yield self._dispatch_event(
@@ -175,16 +178,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
                )

            if current_node_name and current_node_name != self.active_run.get("node_name"):
-                if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
-                    yield self._dispatch_event(
-                        StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
-                    )
-                    self.active_run["node_name"] = None
-
-                yield self._dispatch_event(
-                    StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
-                )
-                self.active_run["node_name"] = current_node_name
+                for ev in self.start_step(current_node_name):
+                    yield ev

            updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
            has_state_diff = updated_state != state
@@ -224,19 +219,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
                        CustomEvent(
                            type=EventType.CUSTOM,
                            name=LangGraphEventTypes.OnInterrupt.value,
-                            value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
+                            value=json.dumps(interrupt.value, default=make_json_safe) if not isinstance(interrupt.value, str) else interrupt.value,
                            raw_event=interrupt,
                        )
                    )

            if self.active_run.get("node_name") != node_name:
-                yield self._dispatch_event(
-                    StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
-                )
-                self.active_run["node_name"] = node_name
-                yield self._dispatch_event(
-                    StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
-                )
+                for ev in self.start_step(node_name):
+                    yield ev

        state_values = state.values if state.values else state
        yield self._dispatch_event(
@@ -250,10 +240,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
            )
        )

-        yield self._dispatch_event(
-            StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
-        )
-        self.active_run["node_name"] = None
+        yield self.end_step()

        yield self._dispatch_event(
            RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
@@ -700,3 +687,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):

        raise ValueError("Message ID not found in history")

+    def start_step(self, step_name: str):
+        if self.active_step:
+            yield self.end_step()
+
+        yield self._dispatch_event(
+            StepStartedEvent(
+                type=EventType.STEP_STARTED,
+                step_name=step_name
+            )
+        )
+        self.active_run["node_name"] = step_name
+        self.active_step = step_name
+
+    def end_step(self):
+        if self.active_step is None:
+            raise ValueError("No active step to end")
+
+        dispatch = self._dispatch_event(
+            StepFinishedEvent(
+                type=EventType.STEP_FINISHED,
+                step_name=self.active_run["node_name"]
+            )
+        )
+
+        self.active_run["node_name"] = None
+        self.active_step = None
+        return dispatch
+
+def make_json_safe(o):
+    if is_dataclass(o):            # dataclasses like Flight(...)
+        return asdict(o)
+    if hasattr(o, "model_dump"):   # pydantic v2
+        return o.model_dump()
+    if hasattr(o, "dict"):         # pydantic v1
+        return o.dict()
+    if hasattr(o, "__dict__"):     # plain objects
+        return vars(o)
+    if isinstance(o, (datetime, date)):
+        return o.isoformat()
+    return str(o)                  # last resort
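
For reference, here is a small standalone sketch (not part of the commit) of how the new make_json_safe fallback behaves when handed to json.dumps via its default= hook, which is how the interrupt handler above now uses it. The Flight dataclass and its fields are hypothetical stand-ins for an interrupt payload; real payloads depend on the graph being run.

# Standalone sketch, not part of the diff above.
# json.dumps calls the default= hook for any value it cannot serialize natively;
# make_json_safe (copied from the commit) converts dataclasses, pydantic models,
# plain objects, and datetimes into JSON-friendly values.
import json
from dataclasses import dataclass, is_dataclass, asdict
from datetime import date, datetime

def make_json_safe(o):
    if is_dataclass(o):              # dataclasses like Flight(...)
        return asdict(o)
    if hasattr(o, "model_dump"):     # pydantic v2
        return o.model_dump()
    if hasattr(o, "dict"):           # pydantic v1
        return o.dict()
    if hasattr(o, "__dict__"):       # plain objects
        return vars(o)
    if isinstance(o, (datetime, date)):
        return o.isoformat()
    return str(o)                    # last resort

@dataclass
class Flight:                        # hypothetical interrupt payload for illustration
    number: str
    departs: datetime

payload = {"flight": Flight("LH123", datetime(2024, 5, 1, 9, 30))}
print(json.dumps(payload, default=make_json_safe))
# -> {"flight": {"number": "LH123", "departs": "2024-05-01T09:30:00"}}

Because the value returned by the default= hook is serialized recursively, the datetime nested inside the dataclass is converted on a second pass through make_json_safe; anything unrecognized falls back to str(o).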