Skip to content

Commit 5c22341

Browse files
authored
fix: recreate step management for langgraph integration (#380)
* fix: recreate step management for langgraph integration
1 parent 45c1f0d commit 5c22341

File tree

5 files changed

+63
-59
lines changed

5 files changed

+63
-59
lines changed

typescript-sdk/integrations/langgraph/examples/python/poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

typescript-sdk/integrations/langgraph/examples/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ langchain-experimental = ">=0.0.11"
2121
langchain-google-genai = ">=2.1.9"
2222
langchain-openai = ">=0.0.1"
2323
langgraph = "^0.6.1"
24-
ag-ui-langgraph = { version = "0.0.12a1", extras = ["fastapi"] }
24+
ag-ui-langgraph = { version = "0.0.12a3", extras = ["fastapi"] }
2525
python-dotenv = "^1.0.0"
2626
fastapi = "^0.115.12"
2727

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
8888
self.messages_in_process: MessagesInProgressRecord = {}
8989
self.active_run: Optional[RunMetadata] = None
9090
self.constant_schema_keys = ['messages', 'tools']
91-
self.active_step = None
9291

9392
def _dispatch_event(self, event: ProcessedEvents) -> str:
9493
if event.type == EventType.RAW:
@@ -121,9 +120,6 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
121120
node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None
122121

123122
self.active_run["manually_emitted_state"] = None
124-
self.active_run["node_name"] = node_name_input
125-
if self.active_run["node_name"] == "__end__":
126-
self.active_run["node_name"] = None
127123

128124
config = ensure_config(self.config.copy() if self.config else {})
129125
config["configurable"] = {**(config.get('configurable', {})), "thread_id": thread_id}
@@ -141,10 +137,11 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
141137
yield self._dispatch_event(
142138
RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
143139
)
140+
self.handle_node_change(node_name_input)
144141

145142
# In case of resume (interrupt), re-start resumed step
146143
if resume_input and self.active_run.get("node_name"):
147-
for ev in self.start_step(self.active_run.get("node_name")):
144+
for ev in self.handle_node_change(self.active_run.get("node_name")):
148145
yield ev
149146

150147
state = prepared_stream_response["state"]
@@ -189,7 +186,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
189186
)
190187

191188
if current_node_name and current_node_name != self.active_run.get("node_name"):
192-
for ev in self.start_step(current_node_name):
189+
for ev in self.handle_node_change(current_node_name):
193190
yield ev
194191

195192
updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
@@ -236,7 +233,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
236233
)
237234

238235
if self.active_run.get("node_name") != node_name:
239-
for ev in self.start_step(node_name):
236+
for ev in self.handle_node_change(node_name):
240237
yield ev
241238

242239
state_values = state.values if state.values else state
@@ -251,7 +248,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
251248
)
252249
)
253250

254-
yield self.end_step()
251+
for ev in self.handle_node_change(None):
252+
yield ev
255253

256254
yield self._dispatch_event(
257255
RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
@@ -730,34 +728,47 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
730728

731729
raise ValueError("Message ID not found in history")
732730

733-
def start_step(self, step_name: str):
734-
if self.active_step:
735-
yield self.end_step()
731+
def handle_node_change(self, node_name: Optional[str]):
732+
"""
733+
Centralized method to handle node name changes and step transitions.
734+
Automatically manages step start/end events based on node name changes.
735+
"""
736+
if node_name == "__end__":
737+
node_name = None
738+
739+
if node_name != self.active_run.get("node_name"):
740+
# End current step if we have one
741+
if self.active_run.get("node_name"):
742+
yield self.end_step()
743+
744+
# Start new step if we have a node name
745+
if node_name:
746+
for event in self.start_step(node_name):
747+
yield event
736748

749+
self.active_run["node_name"] = node_name
750+
751+
def start_step(self, step_name: str):
752+
"""Simple step start event dispatcher - node_name management handled by handle_node_change"""
737753
yield self._dispatch_event(
738754
StepStartedEvent(
739755
type=EventType.STEP_STARTED,
740756
step_name=step_name
741757
)
742758
)
743-
self.active_run["node_name"] = step_name
744-
self.active_step = step_name
745759

746760
def end_step(self):
747-
if self.active_step is None:
761+
"""Simple step end event dispatcher - node_name management handled by handle_node_change"""
762+
if not self.active_run.get("node_name"):
748763
raise ValueError("No active step to end")
749764

750-
dispatch = self._dispatch_event(
765+
return self._dispatch_event(
751766
StepFinishedEvent(
752767
type=EventType.STEP_FINISHED,
753-
step_name=self.active_run["node_name"] or self.active_step
768+
step_name=self.active_run["node_name"]
754769
)
755770
)
756771

757-
self.active_run["node_name"] = None
758-
self.active_step = None
759-
return dispatch
760-
761772
# Check if some kwargs are enabled per LG version, to "catch all versions" and backwards compatibility
762773
def get_stream_kwargs(
763774
self,

typescript-sdk/integrations/langgraph/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "ag-ui-langgraph"
3-
version = "0.0.12-alpha.1"
3+
version = "0.0.12-alpha.3"
44
description = "Implementation of the AG-UI protocol for LangGraph."
55
authors = ["Ran Shem Tov <[email protected]>"]
66
readme = "README.md"

typescript-sdk/integrations/langgraph/src/agent.ts

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ export class LangGraphAgent extends AbstractAgent {
125125
// @ts-expect-error no need to initialize subscriber right now
126126
subscriber: Subscriber<ProcessedEvents>;
127127
constantSchemaKeys: string[] = DEFAULT_SCHEMA_KEYS;
128-
activeStep?: string;
129128
config: LangGraphAgentConfig;
130129

131130
constructor(config: LangGraphAgentConfig) {
@@ -235,11 +234,6 @@ export class LangGraphAgent extends AbstractAgent {
235234
this.activeRun!.manuallyEmittedState = null;
236235

237236
const nodeNameInput = forwardedProps?.nodeName;
238-
this.activeRun!.nodeName = nodeNameInput;
239-
if (this.activeRun!.nodeName === "__end__") {
240-
this.activeRun!.nodeName = undefined;
241-
}
242-
243237
const threadId = inputThreadId ?? randomUUID();
244238

245239
if (!this.assistant) {
@@ -347,6 +341,7 @@ export class LangGraphAgent extends AbstractAgent {
347341
threadId,
348342
runId: input.runId,
349343
});
344+
this.handleNodeChange(nodeNameInput)
350345

351346
interrupts.forEach((interrupt) => {
352347
this.dispatchEvent({
@@ -400,11 +395,7 @@ export class LangGraphAgent extends AbstractAgent {
400395
threadId,
401396
runId: this.activeRun!.id,
402397
});
403-
404-
// In case of resume (interrupt), re-start resumed step
405-
if (forwardedProps?.command?.resume && this.activeRun!.nodeName) {
406-
this.startStep(this.activeRun!.nodeName);
407-
}
398+
this.handleNodeChange(nodeNameInput)
408399

409400
for await (let streamResponseChunk of streamResponse) {
410401
const subgraphsStreamEnabled = input.forwardedProps?.streamSubgraphs;
@@ -460,11 +451,7 @@ export class LangGraphAgent extends AbstractAgent {
460451
this.activeRun!.id = metadata.run_id;
461452

462453
if (currentNodeName && currentNodeName !== this.activeRun!.nodeName) {
463-
if (this.activeRun!.nodeName && this.activeRun!.nodeName !== nodeNameInput) {
464-
this.endStep();
465-
}
466-
467-
this.startStep(currentNodeName);
454+
this.handleNodeChange(currentNodeName)
468455
}
469456

470457
shouldExit =
@@ -482,7 +469,7 @@ export class LangGraphAgent extends AbstractAgent {
482469
// we only want to update the node name under certain conditions
483470
// since we don't need any internal node names to be sent to the frontend
484471
if (this.activeRun!.graphInfo?.["nodes"].some((node) => node.id === currentNodeName)) {
485-
this.activeRun!.nodeName = currentNodeName;
472+
this.handleNodeChange(currentNodeName)
486473
}
487474

488475
updatedState.values = this.activeRun!.manuallyEmittedState ?? latestStateValues;
@@ -523,6 +510,7 @@ export class LangGraphAgent extends AbstractAgent {
523510
const isEndNode = state.next.length === 0;
524511
const writes = state.metadata?.writes ?? {};
525512

513+
// Initialize a new node name to use in the next if block
526514
let newNodeName = this.activeRun!.nodeName!;
527515

528516
if (!interrupts?.length) {
@@ -539,12 +527,10 @@ export class LangGraphAgent extends AbstractAgent {
539527
});
540528
});
541529

542-
if (this.activeRun!.nodeName != newNodeName) {
543-
this.endStep();
544-
this.startStep(newNodeName);
545-
}
530+
this.handleNodeChange(newNodeName);
531+
// Immediately turn off new step
532+
this.handleNodeChange(undefined);
546533

547-
this.endStep();
548534
this.dispatchEvent({
549535
type: EventType.STATE_SNAPSHOT,
550536
snapshot: this.getStateSnapshot(state),
@@ -1017,28 +1003,35 @@ export class LangGraphAgent extends AbstractAgent {
10171003
};
10181004
}
10191005

1020-
startStep(nodeName: string) {
1021-
if (this.activeStep) {
1022-
this.endStep();
1006+
handleNodeChange(nodeName: string | undefined) {
1007+
if (nodeName === "__end__") {
1008+
nodeName = undefined;
10231009
}
1010+
if (nodeName !== this.activeRun?.nodeName) {
1011+
// End current step
1012+
if (this.activeRun?.nodeName) {
1013+
this.endStep();
1014+
}
1015+
// If we actually got a node name, start a new step
1016+
if (nodeName) {
1017+
this.startStep(nodeName);
1018+
}
1019+
}
1020+
this.activeRun!.nodeName = nodeName;
1021+
}
1022+
1023+
startStep(nodeName: string) {
10241024
this.dispatchEvent({
10251025
type: EventType.STEP_STARTED,
10261026
stepName: nodeName,
10271027
});
1028-
this.activeRun!.nodeName = nodeName;
1029-
this.activeStep = nodeName;
10301028
}
10311029

10321030
endStep() {
1033-
if (!this.activeStep) {
1034-
throw new Error("No active step to end");
1035-
}
10361031
this.dispatchEvent({
10371032
type: EventType.STEP_FINISHED,
1038-
stepName: this.activeRun!.nodeName! ?? this.activeStep,
1033+
stepName: this.activeRun!.nodeName!,
10391034
});
1040-
this.activeRun!.nodeName = undefined;
1041-
this.activeStep = undefined;
10421035
}
10431036

10441037
async getCheckpointByMessage(

0 commit comments

Comments
 (0)