Skip to content

Commit e2c0b2e

Browse files
fix(langchain): fix state in wrapToolCall and wrapModelCall (langchain-ai#9306)
1 parent 5d8b75f commit e2c0b2e

File tree

7 files changed

+304
-114
lines changed

7 files changed

+304
-114
lines changed

libs/langchain/src/agents/ReactAgent.ts

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,7 @@ import type { Runnable, RunnableConfig } from "@langchain/core/runnables";
2020
import type { StreamEvent } from "@langchain/core/tracers/log_stream";
2121

2222
import { createAgentAnnotationConditional } from "./annotation.js";
23-
import {
24-
isClientTool,
25-
validateLLMHasNoBoundTools,
26-
wrapToolCall,
27-
} from "./utils.js";
23+
import { isClientTool, validateLLMHasNoBoundTools } from "./utils.js";
2824

2925
import { AgentNode } from "./nodes/AgentNode.js";
3026
import { ToolNode } from "./nodes/ToolNode.js";
@@ -36,6 +32,7 @@ import {
3632
initializeMiddlewareStates,
3733
parseJumpToTarget,
3834
} from "./nodes/utils.js";
35+
import { StateManager } from "./state.js";
3936

4037
import type { WithStateGraphNodes } from "./types.js";
4138
import type { ClientTool, ServerTool } from "./tools.js";
@@ -129,6 +126,8 @@ export class ReactAgent<
129126

130127
#agentNode: AgentNode<any, AnyAnnotationRoot>;
131128

129+
#stateManager = new StateManager();
130+
132131
constructor(
133132
public options: CreateAgentParams<
134133
StructuredResponseFormat,
@@ -245,21 +244,12 @@ export class ReactAgent<
245244
throw new Error(`Middleware ${m.name} is defined multiple times`);
246245
}
247246

248-
const getState = () => {
249-
return {
250-
...beforeAgentNode?.getState(),
251-
...beforeModelNode?.getState(),
252-
...afterModelNode?.getState(),
253-
...afterAgentNode?.getState(),
254-
...this.#agentNode.getState(),
255-
};
256-
};
257-
258247
middlewareNames.add(m.name);
259248
if (m.beforeAgent) {
260249
beforeAgentNode = new BeforeAgentNode(m, {
261-
getState,
250+
getState: () => this.#stateManager.getState(m.name),
262251
});
252+
this.#stateManager.addNode(m, beforeAgentNode);
263253
const name = `${m.name}.before_agent`;
264254
beforeAgentNodes.push({
265255
index: i,
@@ -274,8 +264,9 @@ export class ReactAgent<
274264
}
275265
if (m.beforeModel) {
276266
beforeModelNode = new BeforeModelNode(m, {
277-
getState,
267+
getState: () => this.#stateManager.getState(m.name),
278268
});
269+
this.#stateManager.addNode(m, beforeModelNode);
279270
const name = `${m.name}.before_model`;
280271
beforeModelNodes.push({
281272
index: i,
@@ -290,8 +281,9 @@ export class ReactAgent<
290281
}
291282
if (m.afterModel) {
292283
afterModelNode = new AfterModelNode(m, {
293-
getState,
284+
getState: () => this.#stateManager.getState(m.name),
294285
});
286+
this.#stateManager.addNode(m, afterModelNode);
295287
const name = `${m.name}.after_model`;
296288
afterModelNodes.push({
297289
index: i,
@@ -306,8 +298,9 @@ export class ReactAgent<
306298
}
307299
if (m.afterAgent) {
308300
afterAgentNode = new AfterAgentNode(m, {
309-
getState,
301+
getState: () => this.#stateManager.getState(m.name),
310302
});
303+
this.#stateManager.addNode(m, afterAgentNode);
311304
const name = `${m.name}.after_agent`;
312305
afterAgentNodes.push({
313306
index: i,
@@ -322,32 +315,26 @@ export class ReactAgent<
322315
}
323316

324317
if (m.wrapModelCall) {
325-
wrapModelCallHookMiddleware.push([m, getState]);
318+
wrapModelCallHookMiddleware.push([
319+
m,
320+
() => this.#stateManager.getState(m.name),
321+
]);
326322
}
327323
}
328324

329325
/**
330326
* Add Nodes
331327
*/
332-
allNodeWorkflows.addNode(
333-
"model_request",
334-
this.#agentNode,
335-
AgentNode.nodeOptions
336-
);
337-
338-
/**
339-
* Collect and compose wrapToolCall handlers from middleware
340-
* Wrap each handler with error handling and validation
341-
*/
342-
const wrapToolCallHandler = wrapToolCall(middleware);
328+
allNodeWorkflows.addNode("model_request", this.#agentNode);
343329

344330
/**
345331
* add single tool node for all tools
346332
*/
347333
if (toolClasses.filter(isClientTool).length > 0) {
348334
const toolNode = new ToolNode(toolClasses.filter(isClientTool), {
349335
signal: this.options.signal,
350-
wrapToolCall: wrapToolCallHandler,
336+
middleware,
337+
stateManager: this.#stateManager,
351338
});
352339
allNodeWorkflows.addNode("tools", toolNode);
353340
}

libs/langchain/src/agents/nodes/AgentNode.ts

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
/* eslint-disable no-instanceof/no-instanceof */
22
import { Runnable, RunnableConfig } from "@langchain/core/runnables";
33
import { BaseMessage, AIMessage, ToolMessage } from "@langchain/core/messages";
4-
import { z } from "zod/v3";
54
import { Command, type LangGraphRunnableConfig } from "@langchain/langgraph";
65
import { type LanguageModelLike } from "@langchain/core/language_models/base";
76
import { type BaseChatModelCallOptions } from "@langchain/core/language_models/chat_models";
87
import {
98
InteropZodObject,
109
getSchemaDescription,
1110
interopParse,
11+
interopZodObjectPartial,
1212
} from "@langchain/core/utils/types";
1313
import type { ToolCall } from "@langchain/core/messages/tool";
1414

@@ -403,6 +403,12 @@ export class AgentNode<
403403
> = {
404404
...request,
405405
state: {
406+
...(middleware.stateSchema
407+
? interopParse(
408+
interopZodObjectPartial(middleware.stateSchema),
409+
state
410+
)
411+
: {}),
406412
...currentGetState(),
407413
messages: state.messages,
408414
} as InternalAgentState<StructuredResponseFormat> &
@@ -510,10 +516,7 @@ export class AgentNode<
510516
systemPrompt: this.#options.systemPrompt,
511517
messages: state.messages,
512518
tools: this.#options.toolClasses,
513-
state: {
514-
messages: state.messages,
515-
} as InternalAgentState<StructuredResponseFormat> &
516-
PreHookAnnotation["State"],
519+
state,
517520
runtime: Object.freeze({
518521
context: lgConfig?.context,
519522
writer: lgConfig.writer,
@@ -814,18 +817,6 @@ export class AgentNode<
814817
return modelRunnable;
815818
}
816819

817-
static get nodeOptions(): {
818-
input: z.ZodObject<{
819-
messages: z.ZodArray<z.ZodType<BaseMessage>>;
820-
}>;
821-
} {
822-
return {
823-
input: z.object({
824-
messages: z.array(z.custom<BaseMessage>()),
825-
}),
826-
};
827-
}
828-
829820
getState(): {
830821
messages: BaseMessage[];
831822
} {

libs/langchain/src/agents/nodes/ToolNode.ts

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ import {
2020
import { RunnableCallable } from "../RunnableCallable.js";
2121
import { PreHookAnnotation } from "../annotation.js";
2222
import { mergeAbortSignals } from "./utils.js";
23+
import { wrapToolCall } from "../utils.js";
2324
import { ToolInvocationError } from "../errors.js";
2425
import type {
2526
AnyAnnotationRoot,
26-
WrapToolCallHook,
2727
ToolCallRequest,
2828
ToAnnotationRoot,
2929
} from "../middleware/types.js";
30+
import type { AgentMiddleware } from "../middleware/types.js";
31+
import type { StateManager } from "../state.js";
3032

3133
export interface ToolNodeOptions {
3234
/**
@@ -64,11 +66,13 @@ export interface ToolNodeOptions {
6466
| boolean
6567
| ((error: unknown, toolCall: ToolCall) => ToolMessage | undefined);
6668
/**
67-
* Optional wrapper function for tool execution.
68-
* Allows middleware to intercept and modify tool calls before execution.
69-
* The wrapper receives the tool call request and a handler function to execute the tool.
69+
* The middleware to use for tool execution.
7070
*/
71-
wrapToolCall?: WrapToolCallHook;
71+
middleware?: readonly AgentMiddleware[];
72+
/**
73+
* The state manager to use for tool execution.
74+
*/
75+
stateManager?: StateManager;
7276
}
7377

7478
const isBaseMessageArray = (input: unknown): input is BaseMessage[] =>
@@ -175,13 +179,16 @@ export class ToolNode<
175179
| ((error: unknown, toolCall: ToolCall) => ToolMessage | undefined) =
176180
defaultHandleToolErrors;
177181

178-
wrapToolCall?: WrapToolCallHook;
182+
middleware: readonly AgentMiddleware[] = [];
183+
184+
stateManager?: StateManager;
179185

180186
constructor(
181187
tools: (StructuredToolInterface | DynamicTool | RunnableToolLike)[],
182188
public options?: ToolNodeOptions
183189
) {
184-
const { name, tags, handleToolErrors, wrapToolCall } = options ?? {};
190+
const { name, tags, handleToolErrors, middleware, stateManager, signal } =
191+
options ?? {};
185192
super({
186193
name,
187194
tags,
@@ -194,8 +201,9 @@ export class ToolNode<
194201
});
195202
this.tools = tools;
196203
this.handleToolErrors = handleToolErrors ?? this.handleToolErrors;
197-
this.wrapToolCall = wrapToolCall;
198-
this.signal = options?.signal;
204+
this.middleware = middleware ?? [];
205+
this.signal = signal;
206+
this.stateManager = stateManager;
199207
}
200208

201209
/**
@@ -271,7 +279,7 @@ export class ToolNode<
271279
protected async runTool(
272280
call: ToolCall,
273281
config: RunnableConfig,
274-
state?: ToAnnotationRoot<StateSchema>["State"] & PreHookAnnotation["State"]
282+
state: ToAnnotationRoot<StateSchema>["State"] & PreHookAnnotation["State"]
275283
): Promise<ToolMessage | Command> {
276284
/**
277285
* Define the base handler that executes the tool.
@@ -343,16 +351,24 @@ export class ToolNode<
343351
const request = {
344352
toolCall: call,
345353
tool,
346-
state: state || ({} as any),
354+
state,
347355
runtime,
348356
};
349357

358+
/**
359+
* Collect and compose wrapToolCall handlers from middleware
360+
* Wrap each handler with error handling and validation
361+
*/
362+
const wrapToolCallHandler = this.stateManager
363+
? wrapToolCall(this.middleware, state)
364+
: undefined;
365+
350366
/**
351367
* If wrapToolCall is provided, use it to wrap the tool execution
352368
*/
353-
if (this.wrapToolCall && state) {
369+
if (wrapToolCallHandler) {
354370
try {
355-
return await this.wrapToolCall(request, baseHandler);
371+
return await wrapToolCallHandler(request, baseHandler);
356372
} catch (e: unknown) {
357373
/**
358374
* Handle middleware errors
@@ -381,7 +397,8 @@ export class ToolNode<
381397
let outputs: (ToolMessage | Command)[];
382398

383399
if (isSendInput(state)) {
384-
outputs = [await this.runTool(state.lg_tool_call, config, state)];
400+
const { lg_tool_call, jumpTo, ...newState } = state;
401+
outputs = [await this.runTool(state.lg_tool_call, config, newState)];
385402
} else {
386403
let messages: BaseMessage[];
387404
if (isBaseMessageArray(state)) {

0 commit comments

Comments
 (0)