Skip to content

Commit ff0dfdc

Browse files
authored
feat: enable app context injection for mastra and langgraph (#358)
1 parent 88f24ef commit ff0dfdc

File tree

4 files changed

+47
-16
lines changed

4 files changed

+47
-16
lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
262262
async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig):
263263
state_input = input.state or {}
264264
messages = input.messages or []
265-
tools = input.tools or []
266265
forwarded_props = input.forwarded_props or {}
267266
thread_id = input.thread_id
268267

269268
state_input["messages"] = agent_state.values.get("messages", [])
270269
self.active_run["current_graph_state"] = agent_state.values.copy()
271270
langchain_messages = agui_messages_to_langchain(messages)
272-
state = self.langgraph_default_merge_state(state_input, langchain_messages, tools)
271+
state = self.langgraph_default_merge_state(state_input, langchain_messages, input)
273272
self.active_run["current_graph_state"].update(state)
274273
config["configurable"]["thread_id"] = thread_id
275274
interrupts = agent_state.tasks[0].interrupts if agent_state.tasks and len(agent_state.tasks) > 0 else []
@@ -368,7 +367,7 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
368367
as_node=time_travel_checkpoint.next[0] if time_travel_checkpoint.next else "__start__"
369368
)
370369

371-
stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools)
370+
stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], input)
372371
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
373372
stream = self.graph.astream_events(
374373
stream_input,
@@ -415,7 +414,7 @@ def get_schema_keys(self, config) -> SchemaKeys:
415414
"config": [],
416415
}
417416

418-
def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], tools: Any) -> State:
417+
def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], input: RunAgentInput) -> State:
419418
if messages and isinstance(messages[0], SystemMessage):
420419
messages = messages[1:]
421420

@@ -424,6 +423,7 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage
424423

425424
new_messages = [msg for msg in messages if msg.id not in existing_message_ids]
426425

426+
tools = input.tools or []
427427
tools_as_dicts = []
428428
if tools:
429429
for tool in tools:
@@ -438,6 +438,10 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage
438438
**state,
439439
"messages": new_messages,
440440
"tools": [*state.get("tools", []), *tools_as_dicts],
441+
"ag-ui": {
442+
"tools": [*state.get("tools", []), *tools_as_dicts],
443+
"context": input.context or []
444+
}
441445
}
442446

443447
def get_state_snapshot(self, state: State) -> State:

typescript-sdk/integrations/langgraph/src/agent.ts

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import {
2424
RunMetadata,
2525
PredictStateTool,
2626
LangGraphReasoning,
27+
StateEnrichment,
28+
LangGraphTool,
2729
} from "./types";
2830
import {
2931
AbstractAgent,
@@ -181,7 +183,7 @@ export class LangGraphAgent extends AbstractAgent {
181183
}
182184

183185
async prepareRegenerateStream(input: RegenerateInput, streamMode: StreamMode | StreamMode[]) {
184-
const { threadId, messageCheckpoint, tools } = input;
186+
const { threadId, messageCheckpoint } = input;
185187

186188
const timeTravelCheckpoint = await this.getCheckpointByMessage(
187189
messageCheckpoint!.id!,
@@ -196,7 +198,7 @@ export class LangGraphAgent extends AbstractAgent {
196198
}
197199

198200
const fork = await this.client.threads.updateState(threadId, {
199-
values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], tools),
201+
values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], input),
200202
checkpointId: timeTravelCheckpoint.checkpoint.checkpoint_id!,
201203
asNode: timeTravelCheckpoint.next?.[0] ?? "__start__",
202204
});
@@ -206,7 +208,7 @@ export class LangGraphAgent extends AbstractAgent {
206208
input: this.langGraphDefaultMergeState(
207209
timeTravelCheckpoint.values,
208210
[messageCheckpoint],
209-
tools,
211+
input,
210212
),
211213
// @ts-ignore
212214
checkpointId: fork.checkpoint.checkpoint_id!,
@@ -255,14 +257,14 @@ export class LangGraphAgent extends AbstractAgent {
255257
const stateValuesDiff = this.langGraphDefaultMergeState(
256258
{ ...inputState, messages: agentStateMessages },
257259
inputMessagesToLangchain,
258-
tools,
260+
input,
259261
);
260262
// Messages are a combination of existing messages in state + everything that was newly sent
261263
let threadState = {
262264
...agentState,
263265
values: {
264266
...stateValuesDiff,
265-
messages: [...agentStateMessages, ...stateValuesDiff.messages],
267+
messages: [...agentStateMessages, ...(stateValuesDiff.messages ?? [])],
266268
},
267269
};
268270
let stateValues = threadState.values;
@@ -968,7 +970,7 @@ export class LangGraphAgent extends AbstractAgent {
968970
}
969971
}
970972

971-
langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], tools: any): State {
973+
langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], input: RunAgentExtendedInput): State<StateEnrichment> {
972974
if (messages.length > 0 && "role" in messages[0] && messages[0].role === "system") {
973975
// remove system message
974976
messages = messages.slice(1);
@@ -980,7 +982,7 @@ export class LangGraphAgent extends AbstractAgent {
980982

981983
const newMessages = messages.filter((message) => !existingMessageIds.has(message.id));
982984

983-
const langGraphTools = [...(state.tools ?? []), ...(tools ?? [])].map((tool) => {
985+
const langGraphTools: LangGraphTool[] = [...(state.tools ?? []), ...(input.tools ?? [])].map((tool) => {
984986
if (tool.type) {
985987
return tool;
986988
}
@@ -999,6 +1001,10 @@ export class LangGraphAgent extends AbstractAgent {
9991001
...state,
10001002
messages: newMessages,
10011003
tools: langGraphTools,
1004+
'ag-ui': {
1005+
tools: langGraphTools,
1006+
context: input.context,
1007+
}
10021008
};
10031009
}
10041010

typescript-sdk/integrations/langgraph/src/types.ts

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import { AssistantGraph, Message } from "@langchain/langgraph-sdk";
1+
import { AssistantGraph, Message as LangGraphMessage, } from "@langchain/langgraph-sdk";
22
import { MessageType } from "@langchain/core/messages";
3+
import { RunAgentInput } from "@ag-ui/core";
34

45
export enum LangGraphEventTypes {
56
OnChainStart = "on_chain_start",
@@ -14,7 +15,26 @@ export enum LangGraphEventTypes {
1415
OnInterrupt = "on_interrupt",
1516
}
1617

17-
export type State = Record<string, any>;
18+
export type LangGraphTool = {
19+
type: "function";
20+
function: {
21+
name: string;
22+
description: string;
23+
parameters: any;
24+
},
25+
}
26+
27+
export type State<TDefinedState = Record<string, any>> = {
28+
[k in keyof TDefinedState]: TDefinedState[k] | null;
29+
} & Record<string, any>;
30+
export interface StateEnrichment {
31+
messages: LangGraphMessage[];
32+
tools: LangGraphTool[];
33+
'ag-ui': {
34+
tools: LangGraphTool[];
35+
context: RunAgentInput['context']
36+
}
37+
}
1838

1939
export type SchemaKeys = {
2040
input: string[] | null;
@@ -54,7 +74,7 @@ export interface ToolCall {
5474
}
5575

5676
type BaseLangGraphPlatformMessage = Omit<
57-
Message,
77+
LangGraphMessage,
5878
| "isResultMessage"
5979
| "isTextMessage"
6080
| "isImageMessage"

typescript-sdk/integrations/mastra/src/mastra.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ export class MastraAgent extends AbstractAgent {
5656
super(rest);
5757
this.agent = agent;
5858
this.resourceId = resourceId;
59-
this.runtimeContext = runtimeContext;
59+
this.runtimeContext = runtimeContext ?? new RuntimeContext();
6060
}
6161

6262
protected run(input: RunAgentInput): Observable<BaseEvent> {
@@ -227,7 +227,7 @@ export class MastraAgent extends AbstractAgent {
227227
* @returns The stream of the mastra agent.
228228
*/
229229
private async streamMastraAgent(
230-
{ threadId, runId, messages, tools }: RunAgentInput,
230+
{ threadId, runId, messages, tools, context: inputContext }: RunAgentInput,
231231
{
232232
onTextPart,
233233
onFinishMessagePart,
@@ -250,6 +250,7 @@ export class MastraAgent extends AbstractAgent {
250250
);
251251
const resourceId = this.resourceId ?? threadId;
252252
const convertedMessages = convertAGUIMessagesToMastra(messages);
253+
this.runtimeContext?.set('ag-ui', { context: inputContext });
253254
const runtimeContext = this.runtimeContext;
254255

255256
if (this.isLocalMastraAgent(this.agent)) {

0 commit comments

Comments
 (0)