11import  uuid 
22import  json 
33from  typing  import  Optional , List , Any , Union , AsyncGenerator , Generator 
4+ from  dataclasses  import  is_dataclass , asdict 
5+ from  datetime  import  date , datetime 
46
57from  langgraph .graph .state  import  CompiledStateGraph 
68from  langchain .schema  import  BaseMessage , SystemMessage 
@@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
8587        self .messages_in_process : MessagesInProgressRecord  =  {}
8688        self .active_run : Optional [RunMetadata ] =  None 
8789        self .constant_schema_keys  =  ['messages' , 'tools' ]
90+         self .active_step  =  None 
8891
8992    def  _dispatch_event (self , event : ProcessedEvents ) ->  str :
9093        return  event   # Fallback if no encoder 
@@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
135138
136139        # In case of resume (interrupt), re-start resumed step 
137140        if  resume_input  and  self .active_run .get ("node_name" ):
138-             yield  self ._dispatch_event (
139-                 StepStartedEvent (type = EventType .STEP_STARTED , step_name = self .active_run .get ("node_name" ))
140-             )
141+             for  ev  in  self .start_step (self .active_run .get ("node_name" )):
142+                 yield  ev 
141143
142144        state  =  prepared_stream_response ["state" ]
143145        stream  =  prepared_stream_response ["stream" ]
@@ -151,6 +153,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
151153
152154        should_exit  =  False 
153155        current_graph_state  =  state 
156+         
154157        async  for  event  in  stream :
155158            if  event ["event" ] ==  "error" :
156159                yield  self ._dispatch_event (
@@ -175,16 +178,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
175178                )
176179
177180            if  current_node_name  and  current_node_name  !=  self .active_run .get ("node_name" ):
178-                 if  self .active_run ["node_name" ] and  self .active_run ["node_name" ] !=  node_name_input :
179-                     yield  self ._dispatch_event (
180-                         StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
181-                     )
182-                     self .active_run ["node_name" ] =  None 
183- 
184-                 yield  self ._dispatch_event (
185-                     StepStartedEvent (type = EventType .STEP_STARTED , step_name = current_node_name )
186-                 )
187-                 self .active_run ["node_name" ] =  current_node_name 
181+                 for  ev  in  self .start_step (current_node_name ):
182+                     yield  ev 
188183
189184            updated_state  =  self .active_run .get ("manually_emitted_state" ) or  current_graph_state 
190185            has_state_diff  =  updated_state  !=  state 
@@ -224,19 +219,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
224219                CustomEvent (
225220                    type = EventType .CUSTOM ,
226221                    name = LangGraphEventTypes .OnInterrupt .value ,
227-                     value = json .dumps (interrupt .value ) if  not  isinstance (interrupt .value , str ) else  interrupt .value ,
222+                     value = json .dumps (interrupt .value ,  default = make_json_safe ) if  not  isinstance (interrupt .value , str ) else  interrupt .value ,
228223                    raw_event = interrupt ,
229224                )
230225            )
231226
232227        if  self .active_run .get ("node_name" ) !=  node_name :
233-             yield  self ._dispatch_event (
234-                 StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
235-             )
236-             self .active_run ["node_name" ] =  node_name 
237-             yield  self ._dispatch_event (
238-                 StepStartedEvent (type = EventType .STEP_STARTED , step_name = self .active_run ["node_name" ])
239-             )
228+             for  ev  in  self .start_step (node_name ):
229+                 yield  ev 
240230
241231        state_values  =  state .values  if  state .values  else  state 
242232        yield  self ._dispatch_event (
@@ -250,10 +240,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
250240            )
251241        )
252242
253-         yield  self ._dispatch_event (
254-             StepFinishedEvent (type = EventType .STEP_FINISHED , step_name = self .active_run ["node_name" ])
255-         )
256-         self .active_run ["node_name" ] =  None 
243+         yield  self .end_step ()
257244
258245        yield  self ._dispatch_event (
259246            RunFinishedEvent (type = EventType .RUN_FINISHED , thread_id = thread_id , run_id = self .active_run ["id" ])
@@ -700,3 +687,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
700687
701688        raise  ValueError ("Message ID not found in history" )
702689
690+     def  start_step (self , step_name : str ):
691+         if  self .active_step :
692+             yield  self .end_step ()
693+ 
694+         yield  self ._dispatch_event (
695+             StepStartedEvent (
696+                 type = EventType .STEP_STARTED ,
697+                 step_name = step_name 
698+             )
699+         )
700+         self .active_run ["node_name" ] =  step_name 
701+         self .active_step  =  step_name 
702+ 
703+     def  end_step (self ):
704+         if  self .active_step  is  None :
705+             raise  ValueError ("No active step to end" )
706+ 
707+         dispatch  =  self ._dispatch_event (
708+             StepFinishedEvent (
709+                 type = EventType .STEP_FINISHED ,
710+                 step_name = self .active_run ["node_name" ]
711+             )
712+         )
713+ 
714+         self .active_run ["node_name" ] =  None 
715+         self .active_step  =  None 
716+         return  dispatch 
717+ 
718+ def  make_json_safe (o ):
719+     if  is_dataclass (o ):          # dataclasses like Flight(...) 
720+         return  asdict (o )
721+     if  hasattr (o , "model_dump" ): # pydantic v2 
722+         return  o .model_dump ()
723+     if  hasattr (o , "dict" ):       # pydantic v1 
724+         return  o .dict ()
725+     if  hasattr (o , "__dict__" ):   # plain objects 
726+         return  vars (o )
727+     if  isinstance (o , (datetime , date )):
728+         return  o .isoformat ()
729+     return  str (o )                # last resort 
0 commit comments