From b32d8831d6328d403faba642b5372a2dbd827cba Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 8 Sep 2025 14:52:19 +0200 Subject: [PATCH 1/4] feat: provide agui context to mastras runtime context --- typescript-sdk/integrations/mastra/src/mastra.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/typescript-sdk/integrations/mastra/src/mastra.ts b/typescript-sdk/integrations/mastra/src/mastra.ts index 8c0c2babe..a8f0d4b4c 100644 --- a/typescript-sdk/integrations/mastra/src/mastra.ts +++ b/typescript-sdk/integrations/mastra/src/mastra.ts @@ -227,7 +227,7 @@ export class MastraAgent extends AbstractAgent { * @returns The stream of the mastra agent. */ private async streamMastraAgent( - { threadId, runId, messages, tools }: RunAgentInput, + { threadId, runId, messages, tools, context: inputContext }: RunAgentInput, { onTextPart, onFinishMessagePart, @@ -250,6 +250,7 @@ export class MastraAgent extends AbstractAgent { ); const resourceId = this.resourceId ?? threadId; const convertedMessages = convertAGUIMessagesToMastra(messages); + this.runtimeContext?.set('ag-ui', inputContext); const runtimeContext = this.runtimeContext; if (this.isLocalMastraAgent(this.agent)) { From 36ae9dd2687b6673a8718efcb1cbb3a14fde0808 Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 8 Sep 2025 14:52:38 +0200 Subject: [PATCH 2/4] feat: provide agui context in langgraph state --- .../langgraph/python/ag_ui_langgraph/agent.py | 12 ++++++++---- typescript-sdk/integrations/langgraph/src/agent.ts | 14 +++++++++----- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py index 392a76450..2d4893045 100644 --- a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py +++ b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py @@ -262,14 +262,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig): state_input = input.state or {} messages = input.messages or [] - tools = input.tools or [] forwarded_props = input.forwarded_props or {} thread_id = input.thread_id state_input["messages"] = agent_state.values.get("messages", []) self.active_run["current_graph_state"] = agent_state.values.copy() langchain_messages = agui_messages_to_langchain(messages) - state = self.langgraph_default_merge_state(state_input, langchain_messages, tools) + state = self.langgraph_default_merge_state(state_input, langchain_messages, input) self.active_run["current_graph_state"].update(state) config["configurable"]["thread_id"] = thread_id interrupts = agent_state.tasks[0].interrupts if agent_state.tasks and len(agent_state.tasks) > 0 else [] @@ -368,7 +367,7 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments as_node=time_travel_checkpoint.next[0] if time_travel_checkpoint.next else "__start__" ) - stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools) + stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], input) subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False stream = self.graph.astream_events( stream_input, @@ -415,7 +414,7 @@ def get_schema_keys(self, config) -> SchemaKeys: "config": [], } - def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], tools: Any) -> State: + def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], input: RunAgentInput) -> State: if messages and isinstance(messages[0], SystemMessage): messages = messages[1:] @@ -424,6 +423,7 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage new_messages = [msg for msg in messages if msg.id not in existing_message_ids] + tools = input.tools or [] tools_as_dicts = [] if tools: for tool in tools: @@ -438,6 +438,10 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage **state, "messages": new_messages, "tools": [*state.get("tools", []), *tools_as_dicts], + "ag-ui": { + "tools": [*state.get("tools", []), *tools_as_dicts], + "context": input.context or [] + } } def get_state_snapshot(self, state: State) -> State: diff --git a/typescript-sdk/integrations/langgraph/src/agent.ts b/typescript-sdk/integrations/langgraph/src/agent.ts index 0a34a4a6f..288678af3 100644 --- a/typescript-sdk/integrations/langgraph/src/agent.ts +++ b/typescript-sdk/integrations/langgraph/src/agent.ts @@ -196,7 +196,7 @@ export class LangGraphAgent extends AbstractAgent { } const fork = await this.client.threads.updateState(threadId, { - values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], tools), + values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], input), checkpointId: timeTravelCheckpoint.checkpoint.checkpoint_id!, asNode: timeTravelCheckpoint.next?.[0] ?? "__start__", }); @@ -206,7 +206,7 @@ export class LangGraphAgent extends AbstractAgent { input: this.langGraphDefaultMergeState( timeTravelCheckpoint.values, [messageCheckpoint], - tools, + input, ), // @ts-ignore checkpointId: fork.checkpoint.checkpoint_id!, @@ -255,7 +255,7 @@ export class LangGraphAgent extends AbstractAgent { const stateValuesDiff = this.langGraphDefaultMergeState( { ...inputState, messages: agentStateMessages }, inputMessagesToLangchain, - tools, + input, ); // Messages are a combination of existing messages in state + everything that was newly sent let threadState = { @@ -968,7 +968,7 @@ export class LangGraphAgent extends AbstractAgent { } } - langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], tools: any): State { + langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], input: RunAgentExtendedInput): State { if (messages.length > 0 && "role" in messages[0] && messages[0].role === "system") { // remove system message messages = messages.slice(1); @@ -980,7 +980,7 @@ export class LangGraphAgent extends AbstractAgent { const newMessages = messages.filter((message) => !existingMessageIds.has(message.id)); - const langGraphTools = [...(state.tools ?? []), ...(tools ?? [])].map((tool) => { + const langGraphTools = [...(state.tools ?? []), ...(input.tools ?? [])].map((tool) => { if (tool.type) { return tool; } @@ -999,6 +999,10 @@ export class LangGraphAgent extends AbstractAgent { ...state, messages: newMessages, tools: langGraphTools, + 'ag-ui': { + tools: langGraphTools, + context: input.context, + } }; } From c7830f2d306a463b479c2dfeed3dd3b802e39f3a Mon Sep 17 00:00:00 2001 From: ran Date: Tue, 9 Sep 2025 12:21:09 +0200 Subject: [PATCH 3/4] fix: type adjusted state for langgraph --- .../integrations/langgraph/src/agent.ts | 10 ++++--- .../integrations/langgraph/src/types.ts | 26 ++++++++++++++++--- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/typescript-sdk/integrations/langgraph/src/agent.ts b/typescript-sdk/integrations/langgraph/src/agent.ts index 288678af3..1dd41be9b 100644 --- a/typescript-sdk/integrations/langgraph/src/agent.ts +++ b/typescript-sdk/integrations/langgraph/src/agent.ts @@ -24,6 +24,8 @@ import { RunMetadata, PredictStateTool, LangGraphReasoning, + StateEnrichment, + LangGraphTool, } from "./types"; import { AbstractAgent, @@ -181,7 +183,7 @@ export class LangGraphAgent extends AbstractAgent { } async prepareRegenerateStream(input: RegenerateInput, streamMode: StreamMode | StreamMode[]) { - const { threadId, messageCheckpoint, tools } = input; + const { threadId, messageCheckpoint } = input; const timeTravelCheckpoint = await this.getCheckpointByMessage( messageCheckpoint!.id!, @@ -262,7 +264,7 @@ export class LangGraphAgent extends AbstractAgent { ...agentState, values: { ...stateValuesDiff, - messages: [...agentStateMessages, ...stateValuesDiff.messages], + messages: [...agentStateMessages, ...(stateValuesDiff.messages ?? [])], }, }; let stateValues = threadState.values; @@ -968,7 +970,7 @@ export class LangGraphAgent extends AbstractAgent { } } - langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], input: RunAgentExtendedInput): State { + langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], input: RunAgentExtendedInput): State { if (messages.length > 0 && "role" in messages[0] && messages[0].role === "system") { // remove system message messages = messages.slice(1); @@ -980,7 +982,7 @@ export class LangGraphAgent extends AbstractAgent { const newMessages = messages.filter((message) => !existingMessageIds.has(message.id)); - const langGraphTools = [...(state.tools ?? []), ...(input.tools ?? [])].map((tool) => { + const langGraphTools: LangGraphTool[] = [...(state.tools ?? []), ...(input.tools ?? [])].map((tool) => { if (tool.type) { return tool; } diff --git a/typescript-sdk/integrations/langgraph/src/types.ts b/typescript-sdk/integrations/langgraph/src/types.ts index 0a94756ee..3648cb409 100644 --- a/typescript-sdk/integrations/langgraph/src/types.ts +++ b/typescript-sdk/integrations/langgraph/src/types.ts @@ -1,5 +1,6 @@ -import { AssistantGraph, Message } from "@langchain/langgraph-sdk"; +import { AssistantGraph, Message as LangGraphMessage, } from "@langchain/langgraph-sdk"; import { MessageType } from "@langchain/core/messages"; +import { RunAgentInput } from "@ag-ui/core"; export enum LangGraphEventTypes { OnChainStart = "on_chain_start", @@ -14,7 +15,26 @@ export enum LangGraphEventTypes { OnInterrupt = "on_interrupt", } -export type State = Record; +export type LangGraphTool = { + type: "function"; + function: { + name: string; + description: string; + parameters: any; + }, +} + +export type State> = { + [k in keyof TDefinedState]: TDefinedState[k] | null; +} & Record; +export interface StateEnrichment { + messages: LangGraphMessage[]; + tools: LangGraphTool[]; + 'ag-ui': { + tools: LangGraphTool[]; + context: RunAgentInput['context'] + } +} export type SchemaKeys = { input: string[] | null; @@ -54,7 +74,7 @@ export interface ToolCall { } type BaseLangGraphPlatformMessage = Omit< - Message, + LangGraphMessage, | "isResultMessage" | "isTextMessage" | "isImageMessage" From 2ad6a49c56f7befe849501d71dae0d8c850f4e3a Mon Sep 17 00:00:00 2001 From: ran Date: Tue, 9 Sep 2025 18:12:43 +0200 Subject: [PATCH 4/4] make it work with mastra --- typescript-sdk/integrations/mastra/src/mastra.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/typescript-sdk/integrations/mastra/src/mastra.ts b/typescript-sdk/integrations/mastra/src/mastra.ts index a8f0d4b4c..f8310767b 100644 --- a/typescript-sdk/integrations/mastra/src/mastra.ts +++ b/typescript-sdk/integrations/mastra/src/mastra.ts @@ -56,7 +56,7 @@ export class MastraAgent extends AbstractAgent { super(rest); this.agent = agent; this.resourceId = resourceId; - this.runtimeContext = runtimeContext; + this.runtimeContext = runtimeContext ?? new RuntimeContext(); } protected run(input: RunAgentInput): Observable { @@ -250,7 +250,7 @@ export class MastraAgent extends AbstractAgent { ); const resourceId = this.resourceId ?? threadId; const convertedMessages = convertAGUIMessagesToMastra(messages); - this.runtimeContext?.set('ag-ui', inputContext); + this.runtimeContext?.set('ag-ui', { context: inputContext }); const runtimeContext = this.runtimeContext; if (this.isLocalMastraAgent(this.agent)) {