Skip to content

Commit 0bb74b7

Browse files
committed
feat: add subgraphs support in langgraph integrations
1 parent e58eedc commit 0bb74b7

File tree

2 files changed

+106
-37
lines changed

2 files changed

+106
-37
lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import uuid
22
import json
33
from typing import Optional, List, Any, Union, AsyncGenerator, Generator
4+
from dataclasses import is_dataclass, asdict
5+
from datetime import date, datetime
46

57
from langgraph.graph.state import CompiledStateGraph
68
from langchain.schema import BaseMessage, SystemMessage
@@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
8587
self.messages_in_process: MessagesInProgressRecord = {}
8688
self.active_run: Optional[RunMetadata] = None
8789
self.constant_schema_keys = ['messages', 'tools']
90+
self.active_step = None
8891

8992
def _dispatch_event(self, event: ProcessedEvents) -> str:
9093
return event # Fallback if no encoder
@@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
135138

136139
# In case of resume (interrupt), re-start resumed step
137140
if resume_input and self.active_run.get("node_name"):
138-
yield self._dispatch_event(
139-
StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run.get("node_name"))
140-
)
141+
for ev in self.start_step(self.active_run.get("node_name")):
142+
yield ev
141143

142144
state = prepared_stream_response["state"]
143145
stream = prepared_stream_response["stream"]
@@ -151,7 +153,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
151153

152154
should_exit = False
153155
current_graph_state = state
156+
154157
async for event in stream:
158+
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
159+
is_subgraph_stream = (subgraphs_stream_enabled and (
160+
event.get("event", "").startswith("events") or
161+
event.get("event", "").startswith("values")
162+
))
155163
if event["event"] == "error":
156164
yield self._dispatch_event(
157165
RunErrorEvent(type=EventType.RUN_ERROR, message=event["data"]["message"], raw_event=event)
@@ -175,16 +183,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
175183
)
176184

177185
if current_node_name and current_node_name != self.active_run.get("node_name"):
178-
if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
179-
yield self._dispatch_event(
180-
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
181-
)
182-
self.active_run["node_name"] = None
183-
184-
yield self._dispatch_event(
185-
StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
186-
)
187-
self.active_run["node_name"] = current_node_name
186+
for ev in self.start_step(current_node_name):
187+
yield ev
188188

189189
updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
190190
has_state_diff = updated_state != state
@@ -224,19 +224,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
224224
CustomEvent(
225225
type=EventType.CUSTOM,
226226
name=LangGraphEventTypes.OnInterrupt.value,
227-
value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
227+
value=json.dumps(interrupt.value, default=make_json_safe) if not isinstance(interrupt.value, str) else interrupt.value,
228228
raw_event=interrupt,
229229
)
230230
)
231231

232232
if self.active_run.get("node_name") != node_name:
233-
yield self._dispatch_event(
234-
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
235-
)
236-
self.active_run["node_name"] = node_name
237-
yield self._dispatch_event(
238-
StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
239-
)
233+
for ev in self.start_step(node_name):
234+
yield ev
240235

241236
state_values = state.values if state.values else state
242237
yield self._dispatch_event(
@@ -250,10 +245,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
250245
)
251246
)
252247

253-
yield self._dispatch_event(
254-
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
255-
)
256-
self.active_run["node_name"] = None
248+
yield self.end_step()
257249

258250
yield self._dispatch_event(
259251
RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
@@ -336,8 +328,18 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
336328
)
337329
stream_input = {**forwarded_props, **payload_input} if payload_input else None
338330

331+
332+
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
333+
334+
stream = self.graph.astream_events(
335+
stream_input,
336+
config=config,
337+
subgraps=bool(subgraphs_stream_enabled),
338+
version="v2"
339+
)
340+
339341
return {
340-
"stream": self.graph.astream_events(stream_input, config, version="v2"),
342+
"stream": stream,
341343
"state": state,
342344
"config": config
343345
}
@@ -362,7 +364,13 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
362364
)
363365

364366
stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools)
365-
stream = self.graph.astream_events(stream_input, fork, version="v2")
367+
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
368+
stream = self.graph.astream_events(
369+
stream_input,
370+
fork,
371+
subgraps=bool(subgraphs_stream_enabled),
372+
version="v2"
373+
)
366374

367375
return {
368376
"stream": stream,
@@ -700,3 +708,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
700708

701709
raise ValueError("Message ID not found in history")
702710

711+
def start_step(self, step_name: str):
712+
if self.active_step:
713+
yield self.end_step()
714+
715+
yield self._dispatch_event(
716+
StepStartedEvent(
717+
type=EventType.STEP_STARTED,
718+
step_name=step_name
719+
)
720+
)
721+
self.active_run["node_name"] = step_name
722+
self.active_step = step_name
723+
724+
def end_step(self):
725+
if self.active_step is None:
726+
raise ValueError("No active step to end")
727+
728+
dispatch = self._dispatch_event(
729+
StepFinishedEvent(
730+
type=EventType.STEP_FINISHED,
731+
step_name=self.active_run["node_name"]
732+
)
733+
)
734+
735+
self.active_run["node_name"] = None
736+
self.active_step = None
737+
return dispatch
738+
739+
def make_json_safe(o):
740+
if is_dataclass(o): # dataclasses like Flight(...)
741+
return asdict(o)
742+
if hasattr(o, "model_dump"): # pydantic v2
743+
return o.model_dump()
744+
if hasattr(o, "dict"): # pydantic v1
745+
return o.dict()
746+
if hasattr(o, "__dict__"): # plain objects
747+
return vars(o)
748+
if isinstance(o, (datetime, date)):
749+
return o.isoformat()
750+
return str(o) # last resort

typescript-sdk/integrations/langgraph/src/agent.ts

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ export class LangGraphAgent extends AbstractAgent {
123123
// @ts-expect-error no need to initialize subscriber right now
124124
subscriber: Subscriber<ProcessedEvents>;
125125
constantSchemaKeys: string[] = DEFAULT_SCHEMA_KEYS;
126+
activeStep?: string;
126127

127128
constructor(config: LangGraphAgentConfig) {
128129
super(config);
@@ -193,13 +194,15 @@ export class LangGraphAgent extends AbstractAgent {
193194
}
194195
);
195196

197+
const payload = {
198+
...(input.forwardedProps ?? {}),
199+
input: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [messageCheckpoint], tools),
200+
// @ts-ignore
201+
checkpointId: fork.checkpoint.checkpoint_id!,
202+
streamMode,
203+
};
196204
return {
197-
streamResponse: this.client.runs.stream(threadId, this.assistant.assistant_id, {
198-
input: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [messageCheckpoint], tools),
199-
// @ts-ignore
200-
checkpointId: fork.checkpoint.checkpoint_id!,
201-
streamMode,
202-
}),
205+
streamResponse: this.client.runs.stream(threadId, this.assistant.assistant_id, payload),
203206
state: timeTravelCheckpoint as ThreadState<State>,
204207
streamMode,
205208
};
@@ -359,8 +362,14 @@ export class LangGraphAgent extends AbstractAgent {
359362
}
360363

361364
for await (let streamResponseChunk of streamResponse) {
365+
const subgraphsStreamEnabled = input.forwardedProps?.streamSubgraphs
366+
const isSubgraphStream = (subgraphsStreamEnabled && (
367+
streamResponseChunk.event.startsWith("events") ||
368+
streamResponseChunk.event.startsWith("values")
369+
))
370+
362371
// @ts-ignore
363-
if (!streamMode.includes(streamResponseChunk.event as StreamMode)) {
372+
if (!streamMode.includes(streamResponseChunk.event as StreamMode) && !isSubgraphStream) {
364373
continue;
365374
}
366375

@@ -383,11 +392,19 @@ export class LangGraphAgent extends AbstractAgent {
383392
break;
384393
}
385394

386-
if (streamResponseChunk.event === "updates") continue;
395+
if (streamResponseChunk.event === "updates") {
396+
continue;
397+
}
387398

388399
if (streamResponseChunk.event === "values") {
389400
latestStateValues = chunk.data;
390401
continue;
402+
} else if (subgraphsStreamEnabled && chunk.event.startsWith("values|")) {
403+
latestStateValues = {
404+
...latestStateValues,
405+
...chunk.data,
406+
};
407+
continue;
391408
}
392409

393410
const chunkData = chunk.data;
@@ -467,7 +484,6 @@ export class LangGraphAgent extends AbstractAgent {
467484
newNodeName = isEndNode ? '__end__' : (state.next[0] ?? Object.keys(writes)[0]);
468485
}
469486

470-
471487
interrupts.forEach((interrupt) => {
472488
this.dispatchEvent({
473489
type: EventType.CUSTOM,
@@ -944,22 +960,27 @@ export class LangGraphAgent extends AbstractAgent {
944960
}
945961

946962
startStep(nodeName: string) {
963+
if (this.activeStep) {
964+
this.endStep()
965+
}
947966
this.dispatchEvent({
948967
type: EventType.STEP_STARTED,
949968
stepName: nodeName,
950969
});
951970
this.activeRun!.nodeName = nodeName;
971+
this.activeStep = nodeName;
952972
}
953973

954974
endStep() {
955-
if (!this.activeRun!.nodeName) {
975+
if (!this.activeStep) {
956976
throw new Error("No active step to end");
957977
}
958978
this.dispatchEvent({
959979
type: EventType.STEP_FINISHED,
960980
stepName: this.activeRun!.nodeName!,
961981
});
962982
this.activeRun!.nodeName = undefined;
983+
this.activeStep = undefined;
963984
}
964985

965986
async getCheckpointByMessage(

0 commit comments

Comments
 (0)