11import uuid
22import json
33from typing import Optional , List , Any , Union , AsyncGenerator , Generator
4+ from dataclasses import is_dataclass , asdict
5+ from datetime import date , datetime
46
57from langgraph .graph .state import CompiledStateGraph
68from langchain .schema import BaseMessage , SystemMessage
@@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
8587 self .messages_in_process : MessagesInProgressRecord = {}
8688 self .active_run : Optional [RunMetadata ] = None
8789 self .constant_schema_keys = ['messages' , 'tools' ]
90+ self .active_step = None
8891
8992 def _dispatch_event (self , event : ProcessedEvents ) -> str :
9093 return event # Fallback if no encoder
@@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
135138
136139 # In case of resume (interrupt), re-start resumed step
137140 if resume_input and self .active_run .get ("node_name" ):
138- yield self ._dispatch_event (
139- StepStartedEvent (type = EventType .STEP_STARTED , step_name = self .active_run .get ("node_name" ))
140- )
141+ for ev in self .start_step (self .active_run .get ("node_name" )):
142+ yield ev
141143
142144 state = prepared_stream_response ["state" ]
143145 stream = prepared_stream_response ["stream" ]
@@ -151,6 +153,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
151153
152154 should_exit = False
153155 current_graph_state = state
156+
154157 async for event in stream :
155158 if event ["event" ] == "error" :
156159 yield self ._dispatch_event (
@@ -175,16 +178,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
175178 )
176179
177180 if current_node_name and current_node_name != self .active_run .get ("node_name" ):
178- if self .active_run ["node_name" ] and self .active_run ["node_name" ] != node_name_input :
179- yield self ._dispatch_event (
180- StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
181- )
182- self .active_run ["node_name" ] = None
183-
184- yield self ._dispatch_event (
185- StepStartedEvent (type = EventType .STEP_STARTED , step_name = current_node_name )
186- )
187- self .active_run ["node_name" ] = current_node_name
181+ for ev in self .start_step (current_node_name ):
182+ yield ev
188183
189184 updated_state = self .active_run .get ("manually_emitted_state" ) or current_graph_state
190185 has_state_diff = updated_state != state
@@ -224,19 +219,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
224219 CustomEvent (
225220 type = EventType .CUSTOM ,
226221 name = LangGraphEventTypes .OnInterrupt .value ,
227- value = json .dumps (interrupt .value ) if not isinstance (interrupt .value , str ) else interrupt .value ,
222+ value = json .dumps (interrupt .value , default = make_json_safe ) if not isinstance (interrupt .value , str ) else interrupt .value ,
228223 raw_event = interrupt ,
229224 )
230225 )
231226
232227 if self .active_run .get ("node_name" ) != node_name :
233- yield self ._dispatch_event (
234- StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
235- )
236- self .active_run ["node_name" ] = node_name
237- yield self ._dispatch_event (
238- StepStartedEvent (type = EventType .STEP_STARTED , step_name = self .active_run ["node_name" ])
239- )
228+ for ev in self .start_step (node_name ):
229+ yield ev
240230
241231 state_values = state .values if state .values else state
242232 yield self ._dispatch_event (
@@ -250,10 +240,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
250240 )
251241 )
252242
253- yield self ._dispatch_event (
254- StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
255- )
256- self .active_run ["node_name" ] = None
243+ yield self .end_step ()
257244
258245 yield self ._dispatch_event (
259246 RunFinishedEvent (type = EventType .RUN_FINISHED , thread_id = thread_id , run_id = self .active_run ["id" ])
@@ -700,3 +687,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
700687
701688 raise ValueError ("Message ID not found in history" )
702689
690+ def start_step (self , step_name : str ):
691+ if self .active_step :
692+ yield self .end_step ()
693+
694+ yield self ._dispatch_event (
695+ StepStartedEvent (
696+ type = EventType .STEP_STARTED ,
697+ step_name = step_name
698+ )
699+ )
700+ self .active_run ["node_name" ] = step_name
701+ self .active_step = step_name
702+
703+ def end_step (self ):
704+ if self .active_step is None :
705+ raise ValueError ("No active step to end" )
706+
707+ dispatch = self ._dispatch_event (
708+ StepFinishedEvent (
709+ type = EventType .STEP_FINISHED ,
710+ step_name = self .active_run ["node_name" ]
711+ )
712+ )
713+
714+ self .active_run ["node_name" ] = None
715+ self .active_step = None
716+ return dispatch
717+
718+ def make_json_safe (o ):
719+ if is_dataclass (o ): # dataclasses like Flight(...)
720+ return asdict (o )
721+ if hasattr (o , "model_dump" ): # pydantic v2
722+ return o .model_dump ()
723+ if hasattr (o , "dict" ): # pydantic v1
724+ return o .dict ()
725+ if hasattr (o , "__dict__" ): # plain objects
726+ return vars (o )
727+ if isinstance (o , (datetime , date )):
728+ return o .isoformat ()
729+ return str (o ) # last resort
0 commit comments