Skip to content

Commit 58e664c

Browse files
fix(langchain): improvements on tool call limit middleware (#9321)
1 parent 32bc4c2 commit 58e664c

File tree

5 files changed

+987
-183
lines changed

5 files changed

+987
-183
lines changed

libs/langchain/src/agents/middleware/modelCallLimit.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,5 +192,8 @@ export function modelCallLimitMiddleware(
192192
runModelCallCount: state.runModelCallCount + 1,
193193
threadModelCallCount: state.threadModelCallCount + 1,
194194
}),
195+
afterAgent: () => ({
196+
runModelCallCount: 0,
197+
}),
195198
});
196199
}

libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,28 @@ const toolCallMessage1 = new AIMessage({
1717
},
1818
],
1919
});
20+
const toolCallMessageForRunLimit = new AIMessage({
21+
content: "",
22+
tool_calls: [
23+
{
24+
id: "call_1",
25+
name: "tool_1",
26+
args: { arg1: "arg1" },
27+
},
28+
{
29+
id: "call_2",
30+
name: "tool_2",
31+
args: { arg1: "arg2" },
32+
},
33+
{
34+
id: "call_3",
35+
name: "tool_3",
36+
args: { arg1: "arg3" },
37+
},
38+
],
39+
});
2040
const toolCallMessage2 = new AIMessage({
21-
content: "bar",
41+
content: "",
2242
tool_calls: [
2343
{
2444
id: "call_1",
@@ -28,7 +48,7 @@ const toolCallMessage2 = new AIMessage({
2848
],
2949
});
3050
const toolCallMessage3 = new AIMessage({
31-
content: "baz",
51+
content: "",
3252
tool_calls: [
3353
{
3454
id: "call_1",
@@ -43,9 +63,6 @@ const responseMessage1 = new AIMessage({
4363
const responseMessage2 = new AIMessage({
4464
content: "fuzbaz",
4565
});
46-
const responseMessage3 = new AIMessage({
47-
content: "fuzbazbaz",
48-
});
4966

5067
const tools = [
5168
tool(() => "foobar", {
@@ -56,21 +73,35 @@ const tools = [
5673
name: "tool_2",
5774
description: "tool_2",
5875
}),
76+
tool(() => "barfoo", {
77+
name: "tool_3",
78+
description: "tool_3",
79+
}),
5980
];
6081

6182
describe("ModelCallLimitMiddleware", () => {
6283
describe.each(["throw", "end"] as const)(
6384
"run limit with exit behavior %s",
6485
(exitBehavior) => {
6586
it("should not throw if the run limit exceeds", async () => {
87+
// First invocation: 2 model calls (within limit)
88+
// Call 1: Makes tool calls for 3 tools -> tools execute -> Call 2: Final response
89+
// Second invocation: 3 model calls (exceeds limit of 2)
90+
// Call 1: Makes tool call -> tool executes -> Call 2: Makes tool call -> tool executes -> Call 3: Should fail
6691
const model = new FakeToolCallingChatModel({
6792
responses: [
68-
toolCallMessage1,
93+
// First invocation - Call 1: Makes 3 tool calls
94+
toolCallMessageForRunLimit,
95+
// First invocation - Call 2: Final response after tools execute
6996
responseMessage1,
97+
// Second invocation - Call 1: Makes 1 tool call
98+
toolCallMessage1,
99+
// Second invocation - Call 2: Makes 1 tool call (after first tool executes)
70100
toolCallMessage2,
71-
responseMessage2,
101+
// Second invocation - Call 3: Makes 1 tool call (should fail here - limit exceeded)
72102
toolCallMessage3,
73-
responseMessage3,
103+
// Second invocation - Call 4: Final response (should never reach this)
104+
responseMessage2,
74105
],
75106
});
76107
const middleware = modelCallLimitMiddleware({
@@ -151,14 +182,14 @@ describe("ModelCallLimitMiddleware", () => {
151182
{ messages: ["Hello, world!"] },
152183
config
153184
);
154-
await expect(result.runModelCallCount).toBe(3);
185+
await expect(result.runModelCallCount).toBe(0);
155186
await expect(result.threadModelCallCount).toBe(3);
156187
} else {
157188
const result = await agent2.invoke(
158189
{ messages: ["Hello, world!"] },
159190
config
160191
);
161-
await expect(result.runModelCallCount).toBe(3);
192+
await expect(result.runModelCallCount).toBe(0);
162193
await expect(result.threadModelCallCount).toBe(3);
163194
expect(result.messages.at(-1)?.content).not.toContain(
164195
"Model call limits exceeded"

0 commit comments

Comments
 (0)