Skip to content

Commit bc8cba3

Browse files
committed
fix: recreate step management for langgraph integration fastapi
1 parent ecc2d86 commit bc8cba3

File tree

1 file changed

+31
-20
lines changed
  • typescript-sdk/integrations/langgraph/python/ag_ui_langgraph

1 file changed

+31
-20
lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
8888
self.messages_in_process: MessagesInProgressRecord = {}
8989
self.active_run: Optional[RunMetadata] = None
9090
self.constant_schema_keys = ['messages', 'tools']
91-
self.active_step = None
9291

9392
def _dispatch_event(self, event: ProcessedEvents) -> str:
9493
if event.type == EventType.RAW:
@@ -121,9 +120,6 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
121120
node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None
122121

123122
self.active_run["manually_emitted_state"] = None
124-
self.active_run["node_name"] = node_name_input
125-
if self.active_run["node_name"] == "__end__":
126-
self.active_run["node_name"] = None
127123

128124
config = ensure_config(self.config.copy() if self.config else {})
129125
config["configurable"] = {**(config.get('configurable', {})), "thread_id": thread_id}
@@ -141,10 +137,11 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
141137
yield self._dispatch_event(
142138
RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
143139
)
140+
self.handle_node_change(node_name_input)
144141

145142
# In case of resume (interrupt), re-start resumed step
146143
if resume_input and self.active_run.get("node_name"):
147-
for ev in self.start_step(self.active_run.get("node_name")):
144+
for ev in self.handle_node_change(self.active_run.get("node_name")):
148145
yield ev
149146

150147
state = prepared_stream_response["state"]
@@ -189,7 +186,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
189186
)
190187

191188
if current_node_name and current_node_name != self.active_run.get("node_name"):
192-
for ev in self.start_step(current_node_name):
189+
for ev in self.handle_node_change(current_node_name):
193190
yield ev
194191

195192
updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
@@ -236,7 +233,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
236233
)
237234

238235
if self.active_run.get("node_name") != node_name:
239-
for ev in self.start_step(node_name):
236+
for ev in self.handle_node_change(node_name):
240237
yield ev
241238

242239
state_values = state.values if state.values else state
@@ -251,7 +248,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
251248
)
252249
)
253250

254-
yield self.end_step()
251+
for ev in self.handle_node_change(None):
252+
yield ev
255253

256254
yield self._dispatch_event(
257255
RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
@@ -730,34 +728,47 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
730728

731729
raise ValueError("Message ID not found in history")
732730

733-
def start_step(self, step_name: str):
734-
if self.active_step:
735-
yield self.end_step()
731+
def handle_node_change(self, node_name: Optional[str]):
732+
"""
733+
Centralized method to handle node name changes and step transitions.
734+
Automatically manages step start/end events based on node name changes.
735+
"""
736+
if node_name == "__end__":
737+
node_name = None
738+
739+
if node_name != self.active_run.get("node_name"):
740+
# End current step if we have one
741+
if self.active_run.get("node_name"):
742+
yield self.end_step()
743+
744+
# Start new step if we have a node name
745+
if node_name:
746+
for event in self.start_step(node_name):
747+
yield event
736748

749+
self.handle_node_change(node_name)
750+
751+
def start_step(self, step_name: str):
752+
"""Simple step start event dispatcher - node_name management handled by handle_node_change"""
737753
yield self._dispatch_event(
738754
StepStartedEvent(
739755
type=EventType.STEP_STARTED,
740756
step_name=step_name
741757
)
742758
)
743-
self.active_run["node_name"] = step_name
744-
self.active_step = step_name
745759

746760
def end_step(self):
747-
if self.active_step is None:
761+
"""Simple step end event dispatcher - node_name management handled by handle_node_change"""
762+
if not self.active_run.get("node_name"):
748763
raise ValueError("No active step to end")
749764

750-
dispatch = self._dispatch_event(
765+
return self._dispatch_event(
751766
StepFinishedEvent(
752767
type=EventType.STEP_FINISHED,
753-
step_name=self.active_run["node_name"] or self.active_step
768+
step_name=self.active_run["node_name"]
754769
)
755770
)
756771

757-
self.active_run["node_name"] = None
758-
self.active_step = None
759-
return dispatch
760-
761772
# Check if some kwargs are enabled per LG version, to "catch all versions" and backwards compatibility
762773
def get_stream_kwargs(
763774
self,

0 commit comments

Comments
 (0)