Skip to content

Commit c392056

Browse files
committed
fix: close off dangling steps by properly tracking open ones
1 parent 1a7fcc2 commit c392056

File tree

2 files changed

+62
-28
lines changed

2 files changed

+62
-28
lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import uuid
22
import json
33
from typing import Optional, List, Any, Union, AsyncGenerator, Generator
4+
from dataclasses import is_dataclass, asdict
5+
from datetime import date, datetime
46

57
from langgraph.graph.state import CompiledStateGraph
68
from langchain.schema import BaseMessage, SystemMessage
@@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
8587
self.messages_in_process: MessagesInProgressRecord = {}
8688
self.active_run: Optional[RunMetadata] = None
8789
self.constant_schema_keys = ['messages', 'tools']
90+
self.active_step = None
8891

8992
def _dispatch_event(self, event: ProcessedEvents) -> str:
9093
return event # Fallback if no encoder
@@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
135138

136139
# In case of resume (interrupt), re-start resumed step
137140
if resume_input and self.active_run.get("node_name"):
138-
yield self._dispatch_event(
139-
StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run.get("node_name"))
140-
)
141+
for ev in self.start_step(self.active_run.get("node_name")):
142+
yield ev
141143

142144
state = prepared_stream_response["state"]
143145
stream = prepared_stream_response["stream"]
@@ -151,6 +153,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
151153

152154
should_exit = False
153155
current_graph_state = state
156+
154157
async for event in stream:
155158
if event["event"] == "error":
156159
yield self._dispatch_event(
@@ -175,16 +178,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
175178
)
176179

177180
if current_node_name and current_node_name != self.active_run.get("node_name"):
178-
if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
179-
yield self._dispatch_event(
180-
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
181-
)
182-
self.active_run["node_name"] = None
183-
184-
yield self._dispatch_event(
185-
StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
186-
)
187-
self.active_run["node_name"] = current_node_name
181+
for ev in self.start_step(current_node_name):
182+
yield ev
188183

189184
updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
190185
has_state_diff = updated_state != state
@@ -224,19 +219,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
224219
CustomEvent(
225220
type=EventType.CUSTOM,
226221
name=LangGraphEventTypes.OnInterrupt.value,
227-
value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
222+
value=json.dumps(interrupt.value, default=make_json_safe) if not isinstance(interrupt.value, str) else interrupt.value,
228223
raw_event=interrupt,
229224
)
230225
)
231226

232227
if self.active_run.get("node_name") != node_name:
233-
yield self._dispatch_event(
234-
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
235-
)
236-
self.active_run["node_name"] = node_name
237-
yield self._dispatch_event(
238-
StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
239-
)
228+
for ev in self.start_step(node_name):
229+
yield ev
240230

241231
state_values = state.values if state.values else state
242232
yield self._dispatch_event(
@@ -250,10 +240,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
250240
)
251241
)
252242

253-
yield self._dispatch_event(
254-
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
255-
)
256-
self.active_run["node_name"] = None
243+
yield self.end_step()
257244

258245
yield self._dispatch_event(
259246
RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
@@ -700,3 +687,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
700687

701688
raise ValueError("Message ID not found in history")
702689

690+
def start_step(self, step_name: str):
691+
if self.active_step:
692+
yield self.end_step()
693+
694+
yield self._dispatch_event(
695+
StepStartedEvent(
696+
type=EventType.STEP_STARTED,
697+
step_name=step_name
698+
)
699+
)
700+
self.active_run["node_name"] = step_name
701+
self.active_step = step_name
702+
703+
def end_step(self):
704+
if self.active_step is None:
705+
raise ValueError("No active step to end")
706+
707+
dispatch = self._dispatch_event(
708+
StepFinishedEvent(
709+
type=EventType.STEP_FINISHED,
710+
step_name=self.active_run["node_name"]
711+
)
712+
)
713+
714+
self.active_run["node_name"] = None
715+
self.active_step = None
716+
return dispatch
717+
718+
def make_json_safe(o):
719+
if is_dataclass(o): # dataclasses like Flight(...)
720+
return asdict(o)
721+
if hasattr(o, "model_dump"): # pydantic v2
722+
return o.model_dump()
723+
if hasattr(o, "dict"): # pydantic v1
724+
return o.dict()
725+
if hasattr(o, "__dict__"): # plain objects
726+
return vars(o)
727+
if isinstance(o, (datetime, date)):
728+
return o.isoformat()
729+
return str(o) # last resort

typescript-sdk/integrations/langgraph/src/agent.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ export class LangGraphAgent extends AbstractAgent {
123123
// @ts-expect-error no need to initialize subscriber right now
124124
subscriber: Subscriber<ProcessedEvents>;
125125
constantSchemaKeys: string[] = DEFAULT_SCHEMA_KEYS;
126+
activeStep?: string;
126127

127128
constructor(config: LangGraphAgentConfig) {
128129
super(config);
@@ -383,7 +384,9 @@ export class LangGraphAgent extends AbstractAgent {
383384
break;
384385
}
385386

386-
if (streamResponseChunk.event === "updates") continue;
387+
if (streamResponseChunk.event === "updates") {
388+
continue;
389+
}
387390

388391
if (streamResponseChunk.event === "values") {
389392
latestStateValues = chunk.data;
@@ -467,7 +470,6 @@ export class LangGraphAgent extends AbstractAgent {
467470
newNodeName = isEndNode ? '__end__' : (state.next[0] ?? Object.keys(writes)[0]);
468471
}
469472

470-
471473
interrupts.forEach((interrupt) => {
472474
this.dispatchEvent({
473475
type: EventType.CUSTOM,
@@ -944,22 +946,27 @@ export class LangGraphAgent extends AbstractAgent {
944946
}
945947

946948
startStep(nodeName: string) {
949+
if (this.activeStep) {
950+
this.endStep()
951+
}
947952
this.dispatchEvent({
948953
type: EventType.STEP_STARTED,
949954
stepName: nodeName,
950955
});
951956
this.activeRun!.nodeName = nodeName;
957+
this.activeStep = nodeName;
952958
}
953959

954960
endStep() {
955-
if (!this.activeRun!.nodeName) {
961+
if (!this.activeStep) {
956962
throw new Error("No active step to end");
957963
}
958964
this.dispatchEvent({
959965
type: EventType.STEP_FINISHED,
960966
stepName: this.activeRun!.nodeName!,
961967
});
962968
this.activeRun!.nodeName = undefined;
969+
this.activeStep = undefined;
963970
}
964971

965972
async getCheckpointByMessage(

0 commit comments

Comments
 (0)