diff --git a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py index 392a76450..2d4893045 100644 --- a/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py +++ b/typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py @@ -262,14 +262,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig): state_input = input.state or {} messages = input.messages or [] - tools = input.tools or [] forwarded_props = input.forwarded_props or {} thread_id = input.thread_id state_input["messages"] = agent_state.values.get("messages", []) self.active_run["current_graph_state"] = agent_state.values.copy() langchain_messages = agui_messages_to_langchain(messages) - state = self.langgraph_default_merge_state(state_input, langchain_messages, tools) + state = self.langgraph_default_merge_state(state_input, langchain_messages, input) self.active_run["current_graph_state"].update(state) config["configurable"]["thread_id"] = thread_id interrupts = agent_state.tasks[0].interrupts if agent_state.tasks and len(agent_state.tasks) > 0 else [] @@ -368,7 +367,7 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments as_node=time_travel_checkpoint.next[0] if time_travel_checkpoint.next else "__start__" ) - stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools) + stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], input) subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False stream = self.graph.astream_events( stream_input, @@ -415,7 +414,7 @@ def get_schema_keys(self, config) -> SchemaKeys: "config": [], } - def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], tools: Any) -> State: + def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], input: RunAgentInput) -> State: if messages and isinstance(messages[0], SystemMessage): messages = messages[1:] @@ -424,6 +423,7 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage new_messages = [msg for msg in messages if msg.id not in existing_message_ids] + tools = input.tools or [] tools_as_dicts = [] if tools: for tool in tools: @@ -438,6 +438,10 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage **state, "messages": new_messages, "tools": [*state.get("tools", []), *tools_as_dicts], + "ag-ui": { + "tools": [*state.get("tools", []), *tools_as_dicts], + "context": input.context or [] + } } def get_state_snapshot(self, state: State) -> State: diff --git a/typescript-sdk/integrations/langgraph/src/agent.ts b/typescript-sdk/integrations/langgraph/src/agent.ts index 0a34a4a6f..1dd41be9b 100644 --- a/typescript-sdk/integrations/langgraph/src/agent.ts +++ b/typescript-sdk/integrations/langgraph/src/agent.ts @@ -24,6 +24,8 @@ import { RunMetadata, PredictStateTool, LangGraphReasoning, + StateEnrichment, + LangGraphTool, } from "./types"; import { AbstractAgent, @@ -181,7 +183,7 @@ export class LangGraphAgent extends AbstractAgent { } async prepareRegenerateStream(input: RegenerateInput, streamMode: StreamMode | StreamMode[]) { - const { threadId, messageCheckpoint, tools } = input; + const { threadId, messageCheckpoint } = input; const timeTravelCheckpoint = await this.getCheckpointByMessage( messageCheckpoint!.id!, @@ -196,7 +198,7 @@ export class LangGraphAgent extends AbstractAgent { } const fork = await this.client.threads.updateState(threadId, { - values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], tools), + values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], input), checkpointId: timeTravelCheckpoint.checkpoint.checkpoint_id!, asNode: timeTravelCheckpoint.next?.[0] ?? "__start__", }); @@ -206,7 +208,7 @@ export class LangGraphAgent extends AbstractAgent { input: this.langGraphDefaultMergeState( timeTravelCheckpoint.values, [messageCheckpoint], - tools, + input, ), // @ts-ignore checkpointId: fork.checkpoint.checkpoint_id!, @@ -255,14 +257,14 @@ export class LangGraphAgent extends AbstractAgent { const stateValuesDiff = this.langGraphDefaultMergeState( { ...inputState, messages: agentStateMessages }, inputMessagesToLangchain, - tools, + input, ); // Messages are a combination of existing messages in state + everything that was newly sent let threadState = { ...agentState, values: { ...stateValuesDiff, - messages: [...agentStateMessages, ...stateValuesDiff.messages], + messages: [...agentStateMessages, ...(stateValuesDiff.messages ?? [])], }, }; let stateValues = threadState.values; @@ -968,7 +970,7 @@ export class LangGraphAgent extends AbstractAgent { } } - langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], tools: any): State { + langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], input: RunAgentExtendedInput): State { if (messages.length > 0 && "role" in messages[0] && messages[0].role === "system") { // remove system message messages = messages.slice(1); @@ -980,7 +982,7 @@ export class LangGraphAgent extends AbstractAgent { const newMessages = messages.filter((message) => !existingMessageIds.has(message.id)); - const langGraphTools = [...(state.tools ?? []), ...(tools ?? [])].map((tool) => { + const langGraphTools: LangGraphTool[] = [...(state.tools ?? []), ...(input.tools ?? [])].map((tool) => { if (tool.type) { return tool; } @@ -999,6 +1001,10 @@ export class LangGraphAgent extends AbstractAgent { ...state, messages: newMessages, tools: langGraphTools, + 'ag-ui': { + tools: langGraphTools, + context: input.context, + } }; } diff --git a/typescript-sdk/integrations/langgraph/src/types.ts b/typescript-sdk/integrations/langgraph/src/types.ts index 0a94756ee..3648cb409 100644 --- a/typescript-sdk/integrations/langgraph/src/types.ts +++ b/typescript-sdk/integrations/langgraph/src/types.ts @@ -1,5 +1,6 @@ -import { AssistantGraph, Message } from "@langchain/langgraph-sdk"; +import { AssistantGraph, Message as LangGraphMessage, } from "@langchain/langgraph-sdk"; import { MessageType } from "@langchain/core/messages"; +import { RunAgentInput } from "@ag-ui/core"; export enum LangGraphEventTypes { OnChainStart = "on_chain_start", @@ -14,7 +15,26 @@ export enum LangGraphEventTypes { OnInterrupt = "on_interrupt", } -export type State = Record; +export type LangGraphTool = { + type: "function"; + function: { + name: string; + description: string; + parameters: any; + }, +} + +export type State> = { + [k in keyof TDefinedState]: TDefinedState[k] | null; +} & Record; +export interface StateEnrichment { + messages: LangGraphMessage[]; + tools: LangGraphTool[]; + 'ag-ui': { + tools: LangGraphTool[]; + context: RunAgentInput['context'] + } +} export type SchemaKeys = { input: string[] | null; @@ -54,7 +74,7 @@ export interface ToolCall { } type BaseLangGraphPlatformMessage = Omit< - Message, + LangGraphMessage, | "isResultMessage" | "isTextMessage" | "isImageMessage" diff --git a/typescript-sdk/integrations/mastra/src/mastra.ts b/typescript-sdk/integrations/mastra/src/mastra.ts index 8c0c2babe..f8310767b 100644 --- a/typescript-sdk/integrations/mastra/src/mastra.ts +++ b/typescript-sdk/integrations/mastra/src/mastra.ts @@ -56,7 +56,7 @@ export class MastraAgent extends AbstractAgent { super(rest); this.agent = agent; this.resourceId = resourceId; - this.runtimeContext = runtimeContext; + this.runtimeContext = runtimeContext ?? new RuntimeContext(); } protected run(input: RunAgentInput): Observable { @@ -227,7 +227,7 @@ export class MastraAgent extends AbstractAgent { * @returns The stream of the mastra agent. */ private async streamMastraAgent( - { threadId, runId, messages, tools }: RunAgentInput, + { threadId, runId, messages, tools, context: inputContext }: RunAgentInput, { onTextPart, onFinishMessagePart, @@ -250,6 +250,7 @@ export class MastraAgent extends AbstractAgent { ); const resourceId = this.resourceId ?? threadId; const convertedMessages = convertAGUIMessagesToMastra(messages); + this.runtimeContext?.set('ag-ui', { context: inputContext }); const runtimeContext = this.runtimeContext; if (this.isLocalMastraAgent(this.agent)) {