From c39205698ddc70f6676f74cf55e5482b67820cb6 Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 29 Aug 2025 18:04:57 +0200 Subject: [PATCH] fix: close off dangling steps by properly tracking open ones --- .../langgraph/python/ag_ui_langgraph/agent.py | 77 +++++++++++++------ .../integrations/langgraph/src/agent.ts | 13 +++- 2 files changed, 62 insertions(+), 28 deletions(-) diff --git a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py index 03ff8fe5a..f13d79dda 100644 --- a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py +++ b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py @@ -1,6 +1,8 @@ import uuid import json from typing import Optional, List, Any, Union, AsyncGenerator, Generator +from dataclasses import is_dataclass, asdict +from datetime import date, datetime from langgraph.graph.state import CompiledStateGraph from langchain.schema import BaseMessage, SystemMessage @@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona self.messages_in_process: MessagesInProgressRecord = {} self.active_run: Optional[RunMetadata] = None self.constant_schema_keys = ['messages', 'tools'] + self.active_step = None def _dispatch_event(self, event: ProcessedEvents) -> str: return event # Fallback if no encoder @@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st # In case of resume (interrupt), re-start resumed step if resume_input and self.active_run.get("node_name"): - yield self._dispatch_event( - StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run.get("node_name")) - ) + for ev in self.start_step(self.active_run.get("node_name")): + yield ev state = prepared_stream_response["state"] stream = prepared_stream_response["stream"] @@ -151,6 +153,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st should_exit = False current_graph_state = state + async for event in stream: if event["event"] == "error": yield self._dispatch_event( @@ -175,16 +178,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st ) if current_node_name and current_node_name != self.active_run.get("node_name"): - if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input: - yield self._dispatch_event( - StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"]) - ) - self.active_run["node_name"] = None - - yield self._dispatch_event( - StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name) - ) - self.active_run["node_name"] = current_node_name + for ev in self.start_step(current_node_name): + yield ev updated_state = self.active_run.get("manually_emitted_state") or current_graph_state has_state_diff = updated_state != state @@ -224,19 +219,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st CustomEvent( type=EventType.CUSTOM, name=LangGraphEventTypes.OnInterrupt.value, - value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value, + value=json.dumps(interrupt.value, default=make_json_safe) if not isinstance(interrupt.value, str) else interrupt.value, raw_event=interrupt, ) ) if self.active_run.get("node_name") != node_name: - yield self._dispatch_event( - StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"]) - ) - self.active_run["node_name"] = node_name - yield self._dispatch_event( - StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"]) - ) + for ev in self.start_step(node_name): + yield ev state_values = state.values if state.values else state yield self._dispatch_event( @@ -250,10 +240,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st ) ) - yield self._dispatch_event( - StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"]) - ) - self.active_run["node_name"] = None + yield self.end_step() yield self._dispatch_event( RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"]) @@ -700,3 +687,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str): raise ValueError("Message ID not found in history") + def start_step(self, step_name: str): + if self.active_step: + yield self.end_step() + + yield self._dispatch_event( + StepStartedEvent( + type=EventType.STEP_STARTED, + step_name=step_name + ) + ) + self.active_run["node_name"] = step_name + self.active_step = step_name + + def end_step(self): + if self.active_step is None: + raise ValueError("No active step to end") + + dispatch = self._dispatch_event( + StepFinishedEvent( + type=EventType.STEP_FINISHED, + step_name=self.active_run["node_name"] + ) + ) + + self.active_run["node_name"] = None + self.active_step = None + return dispatch + +def make_json_safe(o): + if is_dataclass(o): # dataclasses like Flight(...) + return asdict(o) + if hasattr(o, "model_dump"): # pydantic v2 + return o.model_dump() + if hasattr(o, "dict"): # pydantic v1 + return o.dict() + if hasattr(o, "__dict__"): # plain objects + return vars(o) + if isinstance(o, (datetime, date)): + return o.isoformat() + return str(o) # last resort diff --git a/typescript-sdk/integrations/langgraph/src/agent.ts b/typescript-sdk/integrations/langgraph/src/agent.ts index c08f50a8a..dde5e51b3 100644 --- a/typescript-sdk/integrations/langgraph/src/agent.ts +++ b/typescript-sdk/integrations/langgraph/src/agent.ts @@ -123,6 +123,7 @@ export class LangGraphAgent extends AbstractAgent { // @ts-expect-error no need to initialize subscriber right now subscriber: Subscriber; constantSchemaKeys: string[] = DEFAULT_SCHEMA_KEYS; + activeStep?: string; constructor(config: LangGraphAgentConfig) { super(config); @@ -383,7 +384,9 @@ export class LangGraphAgent extends AbstractAgent { break; } - if (streamResponseChunk.event === "updates") continue; + if (streamResponseChunk.event === "updates") { + continue; + } if (streamResponseChunk.event === "values") { latestStateValues = chunk.data; @@ -467,7 +470,6 @@ export class LangGraphAgent extends AbstractAgent { newNodeName = isEndNode ? '__end__' : (state.next[0] ?? Object.keys(writes)[0]); } - interrupts.forEach((interrupt) => { this.dispatchEvent({ type: EventType.CUSTOM, @@ -944,15 +946,19 @@ export class LangGraphAgent extends AbstractAgent { } startStep(nodeName: string) { + if (this.activeStep) { + this.endStep() + } this.dispatchEvent({ type: EventType.STEP_STARTED, stepName: nodeName, }); this.activeRun!.nodeName = nodeName; + this.activeStep = nodeName; } endStep() { - if (!this.activeRun!.nodeName) { + if (!this.activeStep) { throw new Error("No active step to end"); } this.dispatchEvent({ @@ -960,6 +966,7 @@ export class LangGraphAgent extends AbstractAgent { stepName: this.activeRun!.nodeName!, }); this.activeRun!.nodeName = undefined; + this.activeStep = undefined; } async getCheckpointByMessage(