 import uuid
 import json
 from typing import Optional, List, Any, Union, AsyncGenerator, Generator
+from dataclasses import is_dataclass, asdict
+from datetime import date, datetime
 
 from langgraph.graph.state import CompiledStateGraph
 from langchain.schema import BaseMessage, SystemMessage
@@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
         self.messages_in_process: MessagesInProgressRecord = {}
         self.active_run: Optional[RunMetadata] = None
         self.constant_schema_keys = ['messages', 'tools']
+        self.active_step = None
 
     def _dispatch_event(self, event: ProcessedEvents) -> str:
         return event  # Fallback if no encoder
@@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
 
         # In case of resume (interrupt), re-start resumed step
         if resume_input and self.active_run.get("node_name"):
-            yield self._dispatch_event(
-                StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run.get("node_name"))
-            )
+            for ev in self.start_step(self.active_run.get("node_name")):
+                yield ev
 
         state = prepared_stream_response["state"]
         stream = prepared_stream_response["stream"]
@@ -151,7 +153,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
 
         should_exit = False
         current_graph_state = state
+
         async for event in stream:
+            subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
+            is_subgraph_stream = (subgraphs_stream_enabled and (
+                event.get("event", "").startswith("events") or
+                event.get("event", "").startswith("values")
+            ))
             if event["event"] == "error":
                 yield self._dispatch_event(
                     RunErrorEvent(type=EventType.RUN_ERROR, message=event["data"]["message"], raw_event=event)
@@ -175,16 +183,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
                 )
 
             if current_node_name and current_node_name != self.active_run.get("node_name"):
-                if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
-                    yield self._dispatch_event(
-                        StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
-                    )
-                    self.active_run["node_name"] = None
-
-                yield self._dispatch_event(
-                    StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
-                )
-                self.active_run["node_name"] = current_node_name
+                for ev in self.start_step(current_node_name):
+                    yield ev
 
             updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
             has_state_diff = updated_state != state
@@ -224,19 +224,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
                 CustomEvent(
                     type=EventType.CUSTOM,
                     name=LangGraphEventTypes.OnInterrupt.value,
-                    value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
+                    value=json.dumps(interrupt.value, default=make_json_safe) if not isinstance(interrupt.value, str) else interrupt.value,
                     raw_event=interrupt,
                 )
             )
 
         if self.active_run.get("node_name") != node_name:
-            yield self._dispatch_event(
-                StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
-            )
-            self.active_run["node_name"] = node_name
-            yield self._dispatch_event(
-                StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
-            )
+            for ev in self.start_step(node_name):
+                yield ev
 
         state_values = state.values if state.values else state
         yield self._dispatch_event(
@@ -250,10 +245,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
             )
         )
 
-        yield self._dispatch_event(
-            StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
-        )
-        self.active_run["node_name"] = None
+        yield self.end_step()
 
         yield self._dispatch_event(
             RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
@@ -336,8 +328,18 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
         )
         stream_input = {**forwarded_props, **payload_input} if payload_input else None
 
+
+        subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
+
+        stream = self.graph.astream_events(
+            stream_input,
+            config=config,
+            subgraphs=bool(subgraphs_stream_enabled),
+            version="v2"
+        )
+
         return {
-            "stream": self.graph.astream_events(stream_input, config, version="v2"),
+            "stream": stream,
             "state": state,
             "config": config
         }
@@ -362,7 +364,13 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
        )
 
         stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools)
-        stream = self.graph.astream_events(stream_input, fork, version="v2")
+        subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
+        stream = self.graph.astream_events(
+            stream_input,
+            fork,
+            subgraphs=bool(subgraphs_stream_enabled),
+            version="v2"
+        )
 
         return {
             "stream": stream,
@@ -700,3 +708,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
 
         raise ValueError("Message ID not found in history")
 
+    def start_step(self, step_name: str):
+        if self.active_step:
+            yield self.end_step()
+
+        yield self._dispatch_event(
+            StepStartedEvent(
+                type=EventType.STEP_STARTED,
+                step_name=step_name
+            )
+        )
+        self.active_run["node_name"] = step_name
+        self.active_step = step_name
+
+    def end_step(self):
+        if self.active_step is None:
+            raise ValueError("No active step to end")
+
+        dispatch = self._dispatch_event(
+            StepFinishedEvent(
+                type=EventType.STEP_FINISHED,
+                step_name=self.active_run["node_name"]
+            )
+        )
+
+        self.active_run["node_name"] = None
+        self.active_step = None
+        return dispatch
+
+def make_json_safe(o):
+    if is_dataclass(o):           # dataclasses like Flight(...)
+        return asdict(o)
+    if hasattr(o, "model_dump"):  # pydantic v2
+        return o.model_dump()
+    if hasattr(o, "dict"):        # pydantic v1
+        return o.dict()
+    if hasattr(o, "__dict__"):    # plain objects
+        return vars(o)
+    if isinstance(o, (datetime, date)):
+        return o.isoformat()
+    return str(o)                 # last resort
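Note (illustrative sketch, not part of the commit): inside the async-for loop, is_subgraph_stream only flags events whose "event" name starts with "events" or "values" while the stream_subgraphs forwarded prop is truthy. A standalone version of that predicate, with hand-written sample event dicts; the helper name is hypothetical:

    def looks_like_subgraph_event(event: dict, subgraphs_stream_enabled: bool) -> bool:
        # Mirrors the check added in _handle_stream_events above.
        name = event.get("event", "")
        return bool(subgraphs_stream_enabled and (name.startswith("events") or name.startswith("values")))

    assert looks_like_subgraph_event({"event": "values"}, subgraphs_stream_enabled=True)
    assert not looks_like_subgraph_event({"event": "on_chat_model_stream"}, subgraphs_stream_enabled=True)
    assert not looks_like_subgraph_event({"event": "values"}, subgraphs_stream_enabled=False)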
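Note (illustrative sketch, not part of the commit): make_json_safe is wired in as the default= hook for json.dumps when an interrupt value is not already a string. A minimal sketch of the fallback chain, assuming make_json_safe from the diff above is in scope and using a hypothetical Flight dataclass (the name is borrowed from the inline comment):

    import json
    from dataclasses import dataclass
    from datetime import datetime

    @dataclass
    class Flight:                  # hypothetical interrupt payload
        number: str
        departs: datetime

    payload = {"flight": Flight("LH123", datetime(2025, 1, 1, 9, 30))}
    # The dataclass is expanded via asdict(); the nested datetime then falls through to isoformat().
    print(json.dumps(payload, default=make_json_safe))
    # {"flight": {"number": "LH123", "departs": "2025-01-01T09:30:00"}}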