diff --git a/typescript-sdk/integrations/langgraph/examples/python/poetry.lock b/typescript-sdk/integrations/langgraph/examples/python/poetry.lock index e9ec71b71..776843d3e 100644 --- a/typescript-sdk/integrations/langgraph/examples/python/poetry.lock +++ b/typescript-sdk/integrations/langgraph/examples/python/poetry.lock @@ -2,14 +2,14 @@ [[package]] name = "ag-ui-langgraph" -version = "0.0.12a1" +version = "0.0.12a3" description = "Implementation of the AG-UI protocol for LangGraph." optional = false python-versions = "<3.14,>=3.10" groups = ["main"] files = [ - {file = "ag_ui_langgraph-0.0.12a1-py3-none-any.whl", hash = "sha256:3c5e6a2b1cea7c91c33f6fa352dacaf23f905b13baa959276f3c22cb5dcbaa59"}, - {file = "ag_ui_langgraph-0.0.12a1.tar.gz", hash = "sha256:13c6034aaa33ec053788cd7dba3a088d7763bf03b19830b4bba6d559546b30b2"}, + {file = "ag_ui_langgraph-0.0.12a3-py3-none-any.whl", hash = "sha256:5bea4056e413bb53952d742da44893d80712ecc9bd77473d9cccd07bd4cd00ee"}, + {file = "ag_ui_langgraph-0.0.12a3.tar.gz", hash = "sha256:0b0734c2a7a12bf98f7c411aaa9be3adacb77808ecb3c42360e5c129ea4b4411"}, ] [package.dependencies] @@ -2970,4 +2970,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<3.14" -content-hash = "42262620b6201375360ae7e3b3a781edd9bba4bb64e2341175fd28921f268a65" +content-hash = "aef84782b941a699c57297f1e9752de447f0c851a4370fbeebd2808601e8e369" diff --git a/typescript-sdk/integrations/langgraph/examples/python/pyproject.toml b/typescript-sdk/integrations/langgraph/examples/python/pyproject.toml index bfd45b38c..71e46a9a5 100644 --- a/typescript-sdk/integrations/langgraph/examples/python/pyproject.toml +++ b/typescript-sdk/integrations/langgraph/examples/python/pyproject.toml @@ -21,7 +21,7 @@ langchain-experimental = ">=0.0.11" langchain-google-genai = ">=2.1.9" langchain-openai = ">=0.0.1" langgraph = "^0.6.1" -ag-ui-langgraph = { version = "0.0.12a1", extras = ["fastapi"] } +ag-ui-langgraph = { version = "0.0.12a3", extras = ["fastapi"] } python-dotenv = "^1.0.0" fastapi = "^0.115.12" diff --git a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py index 92c5b0700..916b65558 100644 --- a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py +++ b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py @@ -88,7 +88,6 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona self.messages_in_process: MessagesInProgressRecord = {} self.active_run: Optional[RunMetadata] = None self.constant_schema_keys = ['messages', 'tools'] - self.active_step = None def _dispatch_event(self, event: ProcessedEvents) -> str: if event.type == EventType.RAW: @@ -121,9 +120,6 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None self.active_run["manually_emitted_state"] = None - self.active_run["node_name"] = node_name_input - if self.active_run["node_name"] == "__end__": - self.active_run["node_name"] = None config = ensure_config(self.config.copy() if self.config else {}) config["configurable"] = {**(config.get('configurable', {})), "thread_id": thread_id} @@ -141,10 +137,11 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st yield self._dispatch_event( RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"]) ) + self.handle_node_change(node_name_input) # In case of resume (interrupt), re-start resumed step if resume_input and self.active_run.get("node_name"): - for ev in self.start_step(self.active_run.get("node_name")): + for ev in self.handle_node_change(self.active_run.get("node_name")): yield ev state = prepared_stream_response["state"] @@ -189,7 +186,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st ) if current_node_name and current_node_name != self.active_run.get("node_name"): - for ev in self.start_step(current_node_name): + for ev in self.handle_node_change(current_node_name): yield ev updated_state = self.active_run.get("manually_emitted_state") or current_graph_state @@ -236,7 +233,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st ) if self.active_run.get("node_name") != node_name: - for ev in self.start_step(node_name): + for ev in self.handle_node_change(node_name): yield ev state_values = state.values if state.values else state @@ -251,7 +248,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st ) ) - yield self.end_step() + for ev in self.handle_node_change(None): + yield ev yield self._dispatch_event( RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"]) @@ -730,34 +728,47 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str): raise ValueError("Message ID not found in history") - def start_step(self, step_name: str): - if self.active_step: - yield self.end_step() + def handle_node_change(self, node_name: Optional[str]): + """ + Centralized method to handle node name changes and step transitions. + Automatically manages step start/end events based on node name changes. + """ + if node_name == "__end__": + node_name = None + + if node_name != self.active_run.get("node_name"): + # End current step if we have one + if self.active_run.get("node_name"): + yield self.end_step() + + # Start new step if we have a node name + if node_name: + for event in self.start_step(node_name): + yield event + self.active_run["node_name"] = node_name + + def start_step(self, step_name: str): + """Simple step start event dispatcher - node_name management handled by handle_node_change""" yield self._dispatch_event( StepStartedEvent( type=EventType.STEP_STARTED, step_name=step_name ) ) - self.active_run["node_name"] = step_name - self.active_step = step_name def end_step(self): - if self.active_step is None: + """Simple step end event dispatcher - node_name management handled by handle_node_change""" + if not self.active_run.get("node_name"): raise ValueError("No active step to end") - dispatch = self._dispatch_event( + return self._dispatch_event( StepFinishedEvent( type=EventType.STEP_FINISHED, - step_name=self.active_run["node_name"] or self.active_step + step_name=self.active_run["node_name"] ) ) - self.active_run["node_name"] = None - self.active_step = None - return dispatch - # Check if some kwargs are enabled per LG version, to "catch all versions" and backwards compatibility def get_stream_kwargs( self, diff --git a/typescript-sdk/integrations/langgraph/python/pyproject.toml b/typescript-sdk/integrations/langgraph/python/pyproject.toml index cf9e27e73..88368f22d 100644 --- a/typescript-sdk/integrations/langgraph/python/pyproject.toml +++ b/typescript-sdk/integrations/langgraph/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ag-ui-langgraph" -version = "0.0.12-alpha.1" +version = "0.0.12-alpha.3" description = "Implementation of the AG-UI protocol for LangGraph." authors = ["Ran Shem Tov "] readme = "README.md" diff --git a/typescript-sdk/integrations/langgraph/src/agent.ts b/typescript-sdk/integrations/langgraph/src/agent.ts index 62b584696..110837cd2 100644 --- a/typescript-sdk/integrations/langgraph/src/agent.ts +++ b/typescript-sdk/integrations/langgraph/src/agent.ts @@ -125,7 +125,6 @@ export class LangGraphAgent extends AbstractAgent { // @ts-expect-error no need to initialize subscriber right now subscriber: Subscriber; constantSchemaKeys: string[] = DEFAULT_SCHEMA_KEYS; - activeStep?: string; config: LangGraphAgentConfig; constructor(config: LangGraphAgentConfig) { @@ -235,11 +234,6 @@ export class LangGraphAgent extends AbstractAgent { this.activeRun!.manuallyEmittedState = null; const nodeNameInput = forwardedProps?.nodeName; - this.activeRun!.nodeName = nodeNameInput; - if (this.activeRun!.nodeName === "__end__") { - this.activeRun!.nodeName = undefined; - } - const threadId = inputThreadId ?? randomUUID(); if (!this.assistant) { @@ -347,6 +341,7 @@ export class LangGraphAgent extends AbstractAgent { threadId, runId: input.runId, }); + this.handleNodeChange(nodeNameInput) interrupts.forEach((interrupt) => { this.dispatchEvent({ @@ -400,11 +395,7 @@ export class LangGraphAgent extends AbstractAgent { threadId, runId: this.activeRun!.id, }); - - // In case of resume (interrupt), re-start resumed step - if (forwardedProps?.command?.resume && this.activeRun!.nodeName) { - this.startStep(this.activeRun!.nodeName); - } + this.handleNodeChange(nodeNameInput) for await (let streamResponseChunk of streamResponse) { const subgraphsStreamEnabled = input.forwardedProps?.streamSubgraphs; @@ -460,11 +451,7 @@ export class LangGraphAgent extends AbstractAgent { this.activeRun!.id = metadata.run_id; if (currentNodeName && currentNodeName !== this.activeRun!.nodeName) { - if (this.activeRun!.nodeName && this.activeRun!.nodeName !== nodeNameInput) { - this.endStep(); - } - - this.startStep(currentNodeName); + this.handleNodeChange(currentNodeName) } shouldExit = @@ -482,7 +469,7 @@ export class LangGraphAgent extends AbstractAgent { // we only want to update the node name under certain conditions // since we don't need any internal node names to be sent to the frontend if (this.activeRun!.graphInfo?.["nodes"].some((node) => node.id === currentNodeName)) { - this.activeRun!.nodeName = currentNodeName; + this.handleNodeChange(currentNodeName) } updatedState.values = this.activeRun!.manuallyEmittedState ?? latestStateValues; @@ -523,6 +510,7 @@ export class LangGraphAgent extends AbstractAgent { const isEndNode = state.next.length === 0; const writes = state.metadata?.writes ?? {}; + // Initialize a new node name to use in the next if block let newNodeName = this.activeRun!.nodeName!; if (!interrupts?.length) { @@ -539,12 +527,10 @@ export class LangGraphAgent extends AbstractAgent { }); }); - if (this.activeRun!.nodeName != newNodeName) { - this.endStep(); - this.startStep(newNodeName); - } + this.handleNodeChange(newNodeName); + // Immediately turn off new step + this.handleNodeChange(undefined); - this.endStep(); this.dispatchEvent({ type: EventType.STATE_SNAPSHOT, snapshot: this.getStateSnapshot(state), @@ -1017,28 +1003,35 @@ export class LangGraphAgent extends AbstractAgent { }; } - startStep(nodeName: string) { - if (this.activeStep) { - this.endStep(); + handleNodeChange(nodeName: string | undefined) { + if (nodeName === "__end__") { + nodeName = undefined; } + if (nodeName !== this.activeRun?.nodeName) { + // End current step + if (this.activeRun?.nodeName) { + this.endStep(); + } + // If we actually got a node name, start a new step + if (nodeName) { + this.startStep(nodeName); + } + } + this.activeRun!.nodeName = nodeName; + } + + startStep(nodeName: string) { this.dispatchEvent({ type: EventType.STEP_STARTED, stepName: nodeName, }); - this.activeRun!.nodeName = nodeName; - this.activeStep = nodeName; } endStep() { - if (!this.activeStep) { - throw new Error("No active step to end"); - } this.dispatchEvent({ type: EventType.STEP_FINISHED, - stepName: this.activeRun!.nodeName! ?? this.activeStep, + stepName: this.activeRun!.nodeName!, }); - this.activeRun!.nodeName = undefined; - this.activeStep = undefined; } async getCheckpointByMessage(