Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig):
state_input = input.state or {}
messages = input.messages or []
tools = input.tools or []
forwarded_props = input.forwarded_props or {}
thread_id = input.thread_id

state_input["messages"] = agent_state.values.get("messages", [])
self.active_run["current_graph_state"] = agent_state.values.copy()
langchain_messages = agui_messages_to_langchain(messages)
state = self.langgraph_default_merge_state(state_input, langchain_messages, tools)
state = self.langgraph_default_merge_state(state_input, langchain_messages, input)
self.active_run["current_graph_state"].update(state)
config["configurable"]["thread_id"] = thread_id
interrupts = agent_state.tasks[0].interrupts if agent_state.tasks and len(agent_state.tasks) > 0 else []
Expand Down Expand Up @@ -368,7 +367,7 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
as_node=time_travel_checkpoint.next[0] if time_travel_checkpoint.next else "__start__"
)

stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], tools)
stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], input)
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
stream = self.graph.astream_events(
stream_input,
Expand Down Expand Up @@ -415,7 +414,7 @@ def get_schema_keys(self, config) -> SchemaKeys:
"config": [],
}

def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], tools: Any) -> State:
def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], input: RunAgentInput) -> State:
if messages and isinstance(messages[0], SystemMessage):
messages = messages[1:]

Expand All @@ -424,6 +423,7 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage

new_messages = [msg for msg in messages if msg.id not in existing_message_ids]

tools = input.tools or []
tools_as_dicts = []
if tools:
for tool in tools:
Expand All @@ -438,6 +438,10 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage
**state,
"messages": new_messages,
"tools": [*state.get("tools", []), *tools_as_dicts],
"ag-ui": {
"tools": [*state.get("tools", []), *tools_as_dicts],
"context": input.context or []
}
}

def get_state_snapshot(self, state: State) -> State:
Expand Down
20 changes: 13 additions & 7 deletions typescript-sdk/integrations/langgraph/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import {
RunMetadata,
PredictStateTool,
LangGraphReasoning,
StateEnrichment,
LangGraphTool,
} from "./types";
import {
AbstractAgent,
Expand Down Expand Up @@ -181,7 +183,7 @@ export class LangGraphAgent extends AbstractAgent {
}

async prepareRegenerateStream(input: RegenerateInput, streamMode: StreamMode | StreamMode[]) {
const { threadId, messageCheckpoint, tools } = input;
const { threadId, messageCheckpoint } = input;

const timeTravelCheckpoint = await this.getCheckpointByMessage(
messageCheckpoint!.id!,
Expand All @@ -196,7 +198,7 @@ export class LangGraphAgent extends AbstractAgent {
}

const fork = await this.client.threads.updateState(threadId, {
values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], tools),
values: this.langGraphDefaultMergeState(timeTravelCheckpoint.values, [], input),
checkpointId: timeTravelCheckpoint.checkpoint.checkpoint_id!,
asNode: timeTravelCheckpoint.next?.[0] ?? "__start__",
});
Expand All @@ -206,7 +208,7 @@ export class LangGraphAgent extends AbstractAgent {
input: this.langGraphDefaultMergeState(
timeTravelCheckpoint.values,
[messageCheckpoint],
tools,
input,
),
// @ts-ignore
checkpointId: fork.checkpoint.checkpoint_id!,
Expand Down Expand Up @@ -255,14 +257,14 @@ export class LangGraphAgent extends AbstractAgent {
const stateValuesDiff = this.langGraphDefaultMergeState(
{ ...inputState, messages: agentStateMessages },
inputMessagesToLangchain,
tools,
input,
);
// Messages are a combination of existing messages in state + everything that was newly sent
let threadState = {
...agentState,
values: {
...stateValuesDiff,
messages: [...agentStateMessages, ...stateValuesDiff.messages],
messages: [...agentStateMessages, ...(stateValuesDiff.messages ?? [])],
},
};
let stateValues = threadState.values;
Expand Down Expand Up @@ -968,7 +970,7 @@ export class LangGraphAgent extends AbstractAgent {
}
}

langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], tools: any): State {
langGraphDefaultMergeState(state: State, messages: LangGraphMessage[], input: RunAgentExtendedInput): State<StateEnrichment> {
if (messages.length > 0 && "role" in messages[0] && messages[0].role === "system") {
// remove system message
messages = messages.slice(1);
Expand All @@ -980,7 +982,7 @@ export class LangGraphAgent extends AbstractAgent {

const newMessages = messages.filter((message) => !existingMessageIds.has(message.id));

const langGraphTools = [...(state.tools ?? []), ...(tools ?? [])].map((tool) => {
const langGraphTools: LangGraphTool[] = [...(state.tools ?? []), ...(input.tools ?? [])].map((tool) => {
if (tool.type) {
return tool;
}
Expand All @@ -999,6 +1001,10 @@ export class LangGraphAgent extends AbstractAgent {
...state,
messages: newMessages,
tools: langGraphTools,
'ag-ui': {
tools: langGraphTools,
context: input.context,
}
};
}

Expand Down
26 changes: 23 additions & 3 deletions typescript-sdk/integrations/langgraph/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { AssistantGraph, Message } from "@langchain/langgraph-sdk";
import { AssistantGraph, Message as LangGraphMessage, } from "@langchain/langgraph-sdk";
import { MessageType } from "@langchain/core/messages";
import { RunAgentInput } from "@ag-ui/core";

export enum LangGraphEventTypes {
OnChainStart = "on_chain_start",
Expand All @@ -14,7 +15,26 @@ export enum LangGraphEventTypes {
OnInterrupt = "on_interrupt",
}

export type State = Record<string, any>;
export type LangGraphTool = {
type: "function";
function: {
name: string;
description: string;
parameters: any;
},
}

export type State<TDefinedState = Record<string, any>> = {
[k in keyof TDefinedState]: TDefinedState[k] | null;
} & Record<string, any>;
export interface StateEnrichment {
messages: LangGraphMessage[];
tools: LangGraphTool[];
'ag-ui': {
tools: LangGraphTool[];
context: RunAgentInput['context']
}
}

export type SchemaKeys = {
input: string[] | null;
Expand Down Expand Up @@ -54,7 +74,7 @@ export interface ToolCall {
}

type BaseLangGraphPlatformMessage = Omit<
Message,
LangGraphMessage,
| "isResultMessage"
| "isTextMessage"
| "isImageMessage"
Expand Down
5 changes: 3 additions & 2 deletions typescript-sdk/integrations/mastra/src/mastra.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export class MastraAgent extends AbstractAgent {
super(rest);
this.agent = agent;
this.resourceId = resourceId;
this.runtimeContext = runtimeContext;
this.runtimeContext = runtimeContext ?? new RuntimeContext();
}

protected run(input: RunAgentInput): Observable<BaseEvent> {
Expand Down Expand Up @@ -227,7 +227,7 @@ export class MastraAgent extends AbstractAgent {
* @returns The stream of the mastra agent.
*/
private async streamMastraAgent(
{ threadId, runId, messages, tools }: RunAgentInput,
{ threadId, runId, messages, tools, context: inputContext }: RunAgentInput,
{
onTextPart,
onFinishMessagePart,
Expand All @@ -250,6 +250,7 @@ export class MastraAgent extends AbstractAgent {
);
const resourceId = this.resourceId ?? threadId;
const convertedMessages = convertAGUIMessagesToMastra(messages);
this.runtimeContext?.set('ag-ui', { context: inputContext });
const runtimeContext = this.runtimeContext;

if (this.isLocalMastraAgent(this.agent)) {
Expand Down
Loading