Skip to content

Commit 7a4a385

Browse files
fix(langchain): make ToolNode implementation less middleware centric (#9315)
1 parent 1838473 commit 7a4a385

File tree

12 files changed

+68
-60
lines changed

12 files changed

+68
-60
lines changed

libs/langchain-core/src/tools/index.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
import type { RunnableFunc } from "../runnables/base.js";
2121
import { isDirectToolOutput, ToolCall, ToolMessage } from "../messages/tool.js";
2222
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
23+
import type { RunnableToolLike } from "../runnables/base.js";
2324
import {
2425
_configHasToolCallId,
2526
_isToolCall,
@@ -773,3 +774,9 @@ function _stringify(content: unknown): string {
773774
return `${content}`;
774775
}
775776
}
777+
778+
export type ServerTool = Record<string, unknown>;
779+
export type ClientTool =
780+
| StructuredToolInterface
781+
| DynamicTool
782+
| RunnableToolLike;

libs/langchain/src/agents/ReactAgent.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@ import { ToolMessage, AIMessage } from "@langchain/core/messages";
1818
import { IterableReadableStream } from "@langchain/core/utils/stream";
1919
import type { Runnable, RunnableConfig } from "@langchain/core/runnables";
2020
import type { StreamEvent } from "@langchain/core/tracers/log_stream";
21+
import type { ClientTool, ServerTool } from "@langchain/core/tools";
2122

2223
import { createAgentAnnotationConditional } from "./annotation.js";
23-
import { isClientTool, validateLLMHasNoBoundTools } from "./utils.js";
24+
import {
25+
isClientTool,
26+
validateLLMHasNoBoundTools,
27+
wrapToolCall,
28+
} from "./utils.js";
2429

2530
import { AgentNode } from "./nodes/AgentNode.js";
2631
import { ToolNode } from "./nodes/ToolNode.js";
@@ -35,7 +40,6 @@ import {
3540
import { StateManager } from "./state.js";
3641

3742
import type { WithStateGraphNodes } from "./types.js";
38-
import type { ClientTool, ServerTool } from "./tools.js";
3943

4044
import type {
4145
CreateAgentParams,
@@ -333,8 +337,7 @@ export class ReactAgent<
333337
if (toolClasses.filter(isClientTool).length > 0) {
334338
const toolNode = new ToolNode(toolClasses.filter(isClientTool), {
335339
signal: this.options.signal,
336-
middleware,
337-
stateManager: this.#stateManager,
340+
wrapToolCall: wrapToolCall(middleware),
338341
});
339342
allNodeWorkflows.addNode("tools", toolNode);
340343
}

libs/langchain/src/agents/middleware.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import type {
33
InteropZodObject,
44
InferInteropZodOutput,
55
} from "@langchain/core/utils/types";
6+
import type { ClientTool, ServerTool } from "@langchain/core/tools";
67

7-
import type { ClientTool, ServerTool } from "./tools.js";
88
import type {
99
AgentMiddleware,
1010
WrapToolCallHook,

libs/langchain/src/agents/middleware/toolRetry.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
/**
22
* Tool retry middleware for agents.
33
*/
4-
5-
import { ToolMessage } from "@langchain/core/messages";
64
import { z } from "zod/v3";
5+
import { ToolMessage } from "@langchain/core/messages";
6+
import type { ClientTool, ServerTool } from "@langchain/core/tools";
7+
78
import { createMiddleware } from "../middleware.js";
8-
import type { ClientTool, ServerTool } from "../tools.js";
99
import type { AgentMiddleware } from "./types.js";
1010
import { sleep } from "./utils.js";
1111

libs/langchain/src/agents/middleware/types.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@ import type { AnnotationRoot } from "@langchain/langgraph";
1111
import type { AIMessage, ToolMessage } from "@langchain/core/messages";
1212
import type { ToolCall } from "@langchain/core/messages/tool";
1313
import type { Command } from "@langchain/langgraph";
14+
import type { ClientTool, ServerTool } from "@langchain/core/tools";
1415

1516
import type { JumpToTarget } from "../constants.js";
16-
import type { ClientTool, ServerTool } from "../tools.js";
1717
import type { Runtime, AgentBuiltInState } from "../runtime.js";
1818
import type { ModelRequest } from "../nodes/types.js";
1919

2020
type PromiseOrValue<T> = T | Promise<T>;
2121

2222
export type AnyAnnotationRoot = AnnotationRoot<any>;
2323

24-
type NormalizedSchemaInput<
24+
export type NormalizedSchemaInput<
2525
TSchema extends InteropZodObject | undefined | never = any
2626
> = [TSchema] extends [never]
2727
? AgentBuiltInState

libs/langchain/src/agents/nodes/AgentNode.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
interopZodObjectPartial,
1212
} from "@langchain/core/utils/types";
1313
import type { ToolCall } from "@langchain/core/messages/tool";
14+
import type { ClientTool, ServerTool } from "@langchain/core/tools";
1415

1516
import { initChatModel } from "../../chat_models/universal.js";
1617
import { MultipleStructuredOutputsError } from "../errors.js";
@@ -32,7 +33,6 @@ import type {
3233
WrapModelCallHandler,
3334
} from "../middleware/types.js";
3435
import type { ModelRequest } from "./types.js";
35-
import type { ClientTool, ServerTool } from "../tools.js";
3636
import { withAgentName } from "../withAgentName.js";
3737
import {
3838
ToolStrategy,

libs/langchain/src/agents/nodes/ToolNode.ts

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@ import {
2020
import { RunnableCallable } from "../RunnableCallable.js";
2121
import { PreHookAnnotation } from "../annotation.js";
2222
import { mergeAbortSignals } from "./utils.js";
23-
import { wrapToolCall } from "../utils.js";
2423
import { ToolInvocationError } from "../errors.js";
2524
import type {
26-
AnyAnnotationRoot,
25+
WrapToolCallHook,
2726
ToolCallRequest,
2827
ToAnnotationRoot,
2928
} from "../middleware/types.js";
30-
import type { AgentMiddleware } from "../middleware/types.js";
31-
import type { StateManager } from "../state.js";
29+
import type { AgentBuiltInState } from "../runtime.js";
3230

3331
export interface ToolNodeOptions {
3432
/**
@@ -66,13 +64,11 @@ export interface ToolNodeOptions {
6664
| boolean
6765
| ((error: unknown, toolCall: ToolCall) => ToolMessage | undefined);
6866
/**
69-
* The middleware to use for tool execution.
67+
* Optional wrapper function for tool execution.
68+
* Allows middleware to intercept and modify tool calls before execution.
69+
* The wrapper receives the tool call request and a handler function to execute the tool.
7070
*/
71-
middleware?: readonly AgentMiddleware[];
72-
/**
73-
* The state manager to use for tool execution.
74-
*/
75-
stateManager?: StateManager;
71+
wrapToolCall?: WrapToolCallHook;
7672
}
7773

7874
const isBaseMessageArray = (input: unknown): input is BaseMessage[] =>
@@ -165,8 +161,8 @@ function defaultHandleToolErrors(
165161
* ```
166162
*/
167163
export class ToolNode<
168-
StateSchema extends AnyAnnotationRoot | InteropZodObject = any,
169-
ContextSchema extends AnyAnnotationRoot | InteropZodObject = any
164+
StateSchema extends InteropZodObject = any,
165+
ContextSchema extends InteropZodObject = any
170166
> extends RunnableCallable<StateSchema, ContextSchema> {
171167
tools: (StructuredToolInterface | DynamicTool | RunnableToolLike)[];
172168

@@ -179,15 +175,13 @@ export class ToolNode<
179175
| ((error: unknown, toolCall: ToolCall) => ToolMessage | undefined) =
180176
defaultHandleToolErrors;
181177

182-
middleware: readonly AgentMiddleware[] = [];
183-
184-
stateManager?: StateManager;
178+
wrapToolCall: WrapToolCallHook | undefined;
185179

186180
constructor(
187181
tools: (StructuredToolInterface | DynamicTool | RunnableToolLike)[],
188182
public options?: ToolNodeOptions
189183
) {
190-
const { name, tags, handleToolErrors, middleware, stateManager, signal } =
184+
const { name, tags, handleToolErrors, signal, wrapToolCall } =
191185
options ?? {};
192186
super({
193187
name,
@@ -201,9 +195,8 @@ export class ToolNode<
201195
});
202196
this.tools = tools;
203197
this.handleToolErrors = handleToolErrors ?? this.handleToolErrors;
204-
this.middleware = middleware ?? [];
205198
this.signal = signal;
206-
this.stateManager = stateManager;
199+
this.wrapToolCall = wrapToolCall;
207200
}
208201

209202
/**
@@ -279,7 +272,7 @@ export class ToolNode<
279272
protected async runTool(
280273
call: ToolCall,
281274
config: RunnableConfig,
282-
state: ToAnnotationRoot<StateSchema>["State"] & PreHookAnnotation["State"]
275+
state: AgentBuiltInState
283276
): Promise<ToolMessage | Command> {
284277
/**
285278
* Define the base handler that executes the tool.
@@ -355,20 +348,12 @@ export class ToolNode<
355348
runtime,
356349
};
357350

358-
/**
359-
* Collect and compose wrapToolCall handlers from middleware
360-
* Wrap each handler with error handling and validation
361-
*/
362-
const wrapToolCallHandler = this.stateManager
363-
? wrapToolCall(this.middleware, state)
364-
: undefined;
365-
366351
/**
367352
* If wrapToolCall is provided, use it to wrap the tool execution
368353
*/
369-
if (wrapToolCallHandler) {
354+
if (this.wrapToolCall) {
370355
try {
371-
return await wrapToolCallHandler(request, baseHandler);
356+
return await this.wrapToolCall(request, baseHandler);
372357
} catch (e: unknown) {
373358
/**
374359
* Handle middleware errors

libs/langchain/src/agents/nodes/types.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type { LanguageModelLike } from "@langchain/core/language_models/base";
22
import type { BaseMessage } from "@langchain/core/messages";
3-
import type { ServerTool, ClientTool } from "../tools.js";
3+
import type { ServerTool, ClientTool } from "@langchain/core/tools";
4+
45
import type { Runtime, AgentBuiltInState } from "../runtime.js";
56

67
/**

libs/langchain/src/agents/tests/middleware.test-d.ts

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ import { describe, it, expectTypeOf } from "vitest";
22
import { z } from "zod/v3";
33
import { HumanMessage, BaseMessage, AIMessage } from "@langchain/core/messages";
44
import { tool } from "@langchain/core/tools";
5+
import type { ServerTool, ClientTool } from "@langchain/core/tools";
56

67
import { createAgent, createMiddleware } from "../index.js";
78
import type { AgentBuiltInState } from "../runtime.js";
8-
import type { ServerTool, ClientTool } from "../tools.js";
99

1010
describe("middleware types", () => {
1111
it("a middleware can define a state schema which is propagated to the result", async () => {
@@ -202,12 +202,34 @@ describe("middleware types", () => {
202202
customDefaultContextProp: string;
203203
customOptionalContextProp?: string;
204204
}>();
205+
expectTypeOf(request.state).toEqualTypeOf<
206+
{
207+
customDefaultStateProp: string;
208+
customOptionalStateProp?: string;
209+
customRequiredStateProp: string;
210+
} & AgentBuiltInState
211+
>();
205212

206213
return handler({
207214
...request,
208215
tools: [tool(() => "result", { name: "toolA" })],
209216
});
210217
},
218+
wrapToolCall: async (request, handler) => {
219+
expectTypeOf(request.runtime.context).toEqualTypeOf<{
220+
customDefaultContextProp: string;
221+
customOptionalContextProp?: string;
222+
}>();
223+
expectTypeOf(request.state).toEqualTypeOf<
224+
{
225+
customDefaultStateProp: string;
226+
customOptionalStateProp?: string;
227+
customRequiredStateProp: string;
228+
} & AgentBuiltInState
229+
>();
230+
231+
return handler(request);
232+
},
211233
});
212234

213235
const agent = createAgent({

libs/langchain/src/agents/tools.ts

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)