Skip to content

Commit 36ae9dd

Browse files
committed
feat: provide agui context in langgraph state
1 parent b32d883 commit 36ae9dd

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
262262
async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig):
263263
state_input = input.state or {}
264264
messages = input.messages or []
265-
tools = input.tools or []
266265
forwarded_props = input.forwarded_props or {}
267266
thread_id = input.thread_id
268267

269268
state_input["messages"] = agent_state.values.get("messages", [])
270269
self.active_run["current_graph_state"] = agent_state.values.copy()
271270
langchain_messages = agui_messages_to_langchain(messages)
272-
state = self.langgraph_default_merge_state(state_input, langchain_messages, tools)
271+
state = self.langgraph_default_merge_state(state_input, langchain_messages, input)
273272
self.active_run["current_graph_state"].update(state)
274273
config["configurable"]["thread_id"] = thread_id
275274
interrupts = agent_state.tasks[0].interrupts if agent_state.tasks and len(agent_state.tasks) > 0 else []
@@ -368,7 +367,7 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
368367
as_node=time_travel_checkpoint.next[0] if time_travel_checkpoint.next else "__start__"
369368
)
370369

371-
stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools)
370+
stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], input)
372371
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
373372
stream = self.graph.astream_events(
374373
stream_input,
@@ -415,7 +414,7 @@ def get_schema_keys(self, config) -> SchemaKeys:
415414
"config": [],
416415
}
417416

418-
def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], tools: Any) -> State:
417+
def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], input: RunAgentInput) -> State:
419418
if messages and isinstance(messages[0], SystemMessage):
420419
messages = messages[1:]
421420

@@ -424,6 +423,7 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage
424423

425424
new_messages = [msg for msg in messages if msg.id not in existing_message_ids]
426425

426+
tools = input.tools or []
427427
tools_as_dicts = []
428428
if tools:
429429
for tool in tools:
@@ -438,6 +438,10 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage
438438
**state,
439439
"messages": new_messages,
440440
"tools": [*state.get("tools", []), *tools_as_dicts],
441+
"ag-ui": {
442+
"tools": [*state.get("tools", []), *tools_as_dicts],
443+
"context": input.context or []
444+
}
441445
}
442446

443447
def get_state_snapshot(self, state: State) -> State:

typescript-sdk/integrations/langgraph/src/agent.ts

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ export class LangGraphAgent extends AbstractAgent {
196196
}
197197

198198
const fork = await this.client.threads.updateState(threadId, {
199-
values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], tools),
199+
values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], input),
200200
checkpointId: timeTravelCheckpoint.checkpoint.checkpoint_id!,
201201
asNode: timeTravelCheckpoint.next?.[0] ?? "__start__",
202202
});
@@ -206,7 +206,7 @@ export class LangGraphAgent extends AbstractAgent {
206206
input: this.langGraphDefaultMergeState(
207207
timeTravelCheckpoint.values,
208208
[messageCheckpoint],
209-
tools,
209+
input,
210210
),
211211
// @ts-ignore
212212
checkpointId: fork.checkpoint.checkpoint_id!,
@@ -255,7 +255,7 @@ export class LangGraphAgent extends AbstractAgent {
255255
const stateValuesDiff = this.langGraphDefaultMergeState(
256256
{ ...inputState, messages: agentStateMessages },
257257
inputMessagesToLangchain,
258-
tools,
258+
input,
259259
);
260260
// Messages are a combination of existing messages in state + everything that was newly sent
261261
let threadState = {
@@ -968,7 +968,7 @@ export class LangGraphAgent extends AbstractAgent {
968968
}
969969
}
970970

971-
langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], tools: any): State {
971+
langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], input: RunAgentExtendedInput): State {
972972
if (messages.length > 0 && "role" in messages[0] && messages[0].role === "system") {
973973
// remove system message
974974
messages = messages.slice(1);
@@ -980,7 +980,7 @@ export class LangGraphAgent extends AbstractAgent {
980980

981981
const newMessages = messages.filter((message) => !existingMessageIds.has(message.id));
982982

983-
const langGraphTools = [...(state.tools ?? []), ...(tools ?? [])].map((tool) => {
983+
const langGraphTools = [...(state.tools ?? []), ...(input.tools ?? [])].map((tool) => {
984984
if (tool.type) {
985985
return tool;
986986
}
@@ -999,6 +999,10 @@ export class LangGraphAgent extends AbstractAgent {
999999
...state,
10001000
messages: newMessages,
10011001
tools: langGraphTools,
1002+
'ag-ui': {
1003+
tools: langGraphTools,
1004+
context: input.context,
1005+
}
10021006
};
10031007
}
10041008

0 commit comments

Comments
 (0)