@@ -88,7 +88,6 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
8888 self .messages_in_process : MessagesInProgressRecord = {}
8989 self .active_run : Optional [RunMetadata ] = None
9090 self .constant_schema_keys = ['messages' , 'tools' ]
91- self .active_step = None
9291
9392 def _dispatch_event (self , event : ProcessedEvents ) -> str :
9493 if event .type == EventType .RAW :
@@ -121,9 +120,6 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
121120 node_name_input = forwarded_props .get ('node_name' , None ) if forwarded_props else None
122121
123122 self .active_run ["manually_emitted_state" ] = None
124- self .active_run ["node_name" ] = node_name_input
125- if self .active_run ["node_name" ] == "__end__" :
126- self .active_run ["node_name" ] = None
127123
128124 config = ensure_config (self .config .copy () if self .config else {})
129125 config ["configurable" ] = {** (config .get ('configurable' , {})), "thread_id" : thread_id }
@@ -141,10 +137,11 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
141137 yield self ._dispatch_event (
142138 RunStartedEvent (type = EventType .RUN_STARTED , thread_id = thread_id , run_id = self .active_run ["id" ])
143139 )
140+ self .handle_node_change (node_name_input )
144141
145142 # In case of resume (interrupt), re-start resumed step
146143 if resume_input and self .active_run .get ("node_name" ):
147- for ev in self .start_step (self .active_run .get ("node_name" )):
144+ for ev in self .handle_node_change (self .active_run .get ("node_name" )):
148145 yield ev
149146
150147 state = prepared_stream_response ["state" ]
@@ -189,7 +186,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
189186 )
190187
191188 if current_node_name and current_node_name != self .active_run .get ("node_name" ):
192- for ev in self .start_step (current_node_name ):
189+ for ev in self .handle_node_change (current_node_name ):
193190 yield ev
194191
195192 updated_state = self .active_run .get ("manually_emitted_state" ) or current_graph_state
@@ -236,7 +233,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
236233 )
237234
238235 if self .active_run .get ("node_name" ) != node_name :
239- for ev in self .start_step (node_name ):
236+ for ev in self .handle_node_change (node_name ):
240237 yield ev
241238
242239 state_values = state .values if state .values else state
@@ -251,7 +248,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
251248 )
252249 )
253250
254- yield self .end_step ()
251+ for ev in self .handle_node_change (None ):
252+ yield ev
255253
256254 yield self ._dispatch_event (
257255 RunFinishedEvent (type = EventType .RUN_FINISHED , thread_id = thread_id , run_id = self .active_run ["id" ])
@@ -730,34 +728,47 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
730728
731729 raise ValueError ("Message ID not found in history" )
732730
733- def start_step (self , step_name : str ):
734- if self .active_step :
735- yield self .end_step ()
731+ def handle_node_change (self , node_name : Optional [str ]):
732+ """
733+ Centralized method to handle node name changes and step transitions.
734+ Automatically manages step start/end events based on node name changes.
735+ """
736+ if node_name == "__end__" :
737+ node_name = None
738+
739+ if node_name != self .active_run .get ("node_name" ):
740+ # End current step if we have one
741+ if self .active_run .get ("node_name" ):
742+ yield self .end_step ()
743+
744+ # Start new step if we have a node name
745+ if node_name :
746+ for event in self .start_step (node_name ):
747+ yield event
736748
749+ self .handle_node_change (node_name )
750+
751+ def start_step (self , step_name : str ):
752+ """Simple step start event dispatcher - node_name management handled by handle_node_change"""
737753 yield self ._dispatch_event (
738754 StepStartedEvent (
739755 type = EventType .STEP_STARTED ,
740756 step_name = step_name
741757 )
742758 )
743- self .active_run ["node_name" ] = step_name
744- self .active_step = step_name
745759
746760 def end_step (self ):
747- if self .active_step is None :
761+ """Simple step end event dispatcher - node_name management handled by handle_node_change"""
762+ if not self .active_run .get ("node_name" ):
748763 raise ValueError ("No active step to end" )
749764
750- dispatch = self ._dispatch_event (
765+ return self ._dispatch_event (
751766 StepFinishedEvent (
752767 type = EventType .STEP_FINISHED ,
753- step_name = self .active_run ["node_name" ] or self . active_step
768+ step_name = self .active_run ["node_name" ]
754769 )
755770 )
756771
757- self .active_run ["node_name" ] = None
758- self .active_step = None
759- return dispatch
760-
761772 # Check if some kwargs are enabled per LG version, to "catch all versions" and backwards compatibility
762773 def get_stream_kwargs (
763774 self ,
0 commit comments