Skip to content

Commit 2286d0d

Browse files
nfcampos and eyurtsev authored
feat(langchain_v1): Add ToolCallLimitMiddleware (#33269)
which implements a tool call budget for either all tools, or a specific tool --------- Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent 46b87e4 commit 2286d0d

File tree

6 files changed

+725
-20
lines changed

6 files changed

+725
-20
lines changed

.claude/settings.local.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"permissions": {
3+
"allow": [
4+
"Bash(uv run:*)",
5+
"Bash(make:*)"
6+
],
7+
"deny": [],
8+
"ask": []
9+
}
10+
}

libs/langchain_v1/langchain/agents/middleware/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .planning import PlanningMiddleware
55
from .prompt_caching import AnthropicPromptCachingMiddleware
66
from .summarization import SummarizationMiddleware
7+
from .tool_call_limit import ToolCallLimitMiddleware
78
from .types import (
89
AgentMiddleware,
910
AgentState,
@@ -24,6 +25,7 @@
2425
"ModelRequest",
2526
"PlanningMiddleware",
2627
"SummarizationMiddleware",
28+
"ToolCallLimitMiddleware",
2729
"after_model",
2830
"before_model",
2931
"dynamic_prompt",
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
"""Tool call limit middleware for agents."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any, Literal
6+
7+
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
8+
9+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
10+
11+
if TYPE_CHECKING:
12+
from langgraph.runtime import Runtime
13+
14+
15+
def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | None = None) -> int:
    """Tally tool calls across a list of messages.

    Args:
        messages: Messages to scan.
        tool_name: When given, only calls to this specific tool are counted;
            when None, every tool call counts.

    Returns:
        Number of matching tool calls found.
    """
    total = 0
    for msg in messages:
        # Only AI messages carry tool calls; skip everything else.
        if not (isinstance(msg, AIMessage) and msg.tool_calls):
            continue
        if tool_name is None:
            total += len(msg.tool_calls)
        else:
            total += sum(1 for call in msg.tool_calls if call["name"] == tool_name)
    return total
37+
38+
def _get_run_messages(messages: list[AnyMessage]) -> list[AnyMessage]:
    """Return the messages belonging to the current run.

    A "run" is everything after the most recent HumanMessage. When the
    history contains no HumanMessage at all, the whole list is returned.

    Args:
        messages: Full message history.

    Returns:
        The slice of messages after the last HumanMessage.
    """
    # Walk backwards so we stop at the most recent human turn.
    for idx in range(len(messages) - 1, -1, -1):
        if isinstance(messages[idx], HumanMessage):
            return messages[idx + 1 :]
    # No human turn found: the entire history is the current run.
    return messages
61+
62+
def _build_tool_limit_exceeded_message(
63+
thread_count: int,
64+
run_count: int,
65+
thread_limit: int | None,
66+
run_limit: int | None,
67+
tool_name: str | None,
68+
) -> str:
69+
"""Build a message indicating which tool call limits were exceeded.
70+
71+
Args:
72+
thread_count: Current thread tool call count.
73+
run_count: Current run tool call count.
74+
thread_limit: Thread tool call limit (if set).
75+
run_limit: Run tool call limit (if set).
76+
tool_name: Tool name being limited (if specific tool), or None for all tools.
77+
78+
Returns:
79+
A formatted message describing which limits were exceeded.
80+
"""
81+
tool_desc = f"'{tool_name}' tool call" if tool_name else "Tool call"
82+
exceeded_limits = []
83+
if thread_limit is not None and thread_count >= thread_limit:
84+
exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
85+
if run_limit is not None and run_count >= run_limit:
86+
exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
87+
88+
return f"{tool_desc} limits exceeded: {', '.join(exceeded_limits)}"
89+
90+
91+
class ToolCallLimitExceededError(Exception):
    """Raised when a configured tool call limit has been reached.

    Only raised when the middleware's exit behavior is ``"error"`` and the
    thread-level or run-level tool call count has met its limit.
    """

    def __init__(
        self,
        thread_count: int,
        run_count: int,
        thread_limit: int | None,
        run_limit: int | None,
        tool_name: str | None = None,
    ) -> None:
        """Record the counts/limits and build the exception message.

        Args:
            thread_count: Tool calls counted across the whole thread.
            run_count: Tool calls counted in the current run.
            thread_limit: Configured per-thread limit, or None when unset.
            run_limit: Configured per-run limit, or None when unset.
            tool_name: The limited tool's name, or None when all tools count.
        """
        self.thread_count = thread_count
        self.run_count = run_count
        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.tool_name = tool_name
        super().__init__(
            _build_tool_limit_exceeded_message(
                thread_count, run_count, thread_limit, run_limit, tool_name
            )
        )
126+
127+
class ToolCallLimitMiddleware(AgentMiddleware):
    """Middleware enforcing a budget on tool calls.

    Counts tool calls at two scopes and stops the agent once a configured
    limit is reached:

    * Thread scope: every tool call in the full message history, which
      persists across successive runs (invocations) of the agent.
    * Run scope: only the tool calls made after the most recent
      HumanMessage, i.e. during the current invocation.

    The budget may cover all tools or a single named tool, and the behavior
    on hitting a limit is configurable (graceful end vs. raising an error).

    Example:
        ```python
        from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
        from langchain.agents import create_agent

        # Limit all tool calls globally
        global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")

        # Limit a specific tool
        search_limiter = ToolCallLimitMiddleware(
            tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
        )

        # Use both in the same agent
        agent = create_agent("openai:gpt-4o", middleware=[global_limiter, search_limiter])

        result = await agent.invoke({"messages": [HumanMessage("Help me with a task")]})
        ```
    """

    def __init__(
        self,
        *,
        tool_name: str | None = None,
        thread_limit: int | None = None,
        run_limit: int | None = None,
        exit_behavior: Literal["end", "error"] = "end",
    ) -> None:
        """Configure the limits enforced by this middleware.

        Args:
            tool_name: Restrict counting to this tool. None (the default)
                counts calls to every tool.
            thread_limit: Maximum tool calls per thread; None disables the
                thread-scope check. Defaults to None.
            run_limit: Maximum tool calls per run; None disables the
                run-scope check. Defaults to None.
            exit_behavior: What happens when a limit is reached.
                - "end": jump to the end of the agent graph, injecting an
                  artificial AI message that explains which limit was hit.
                - "error": raise a ToolCallLimitExceededError.
                Defaults to "end".

        Raises:
            ValueError: If neither limit is provided, or exit_behavior is
                not one of the supported values.
        """
        super().__init__()

        if thread_limit is None and run_limit is None:
            msg = "At least one limit must be specified (thread_limit or run_limit)"
            raise ValueError(msg)

        if exit_behavior not in ("end", "error"):
            msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
            raise ValueError(msg)

        self.tool_name = tool_name
        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.exit_behavior = exit_behavior

    @property
    def name(self) -> str:
        """Instance name, qualified by the tool name when one is set.

        Distinct names allow several instances of this middleware (each
        limiting a different tool) to coexist in one agent.
        """
        cls_name = self.__class__.__name__
        return f"{cls_name}[{self.tool_name}]" if self.tool_name else cls_name

    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Enforce the configured limits before each model call.

        Args:
            state: Current agent state (provides the message history).
            runtime: The langgraph runtime (unused here).

        Returns:
            When a limit is hit and exit_behavior is "end", a state update
            that jumps to the end and appends an explanatory AIMessage;
            otherwise None.

        Raises:
            ToolCallLimitExceededError: A limit is hit and exit_behavior
                is "error".
        """
        history = state.get("messages", [])

        # Thread scope covers the whole history; run scope covers only the
        # messages after the most recent human turn.
        thread_count = _count_tool_calls_in_messages(history, self.tool_name)
        run_count = _count_tool_calls_in_messages(_get_run_messages(history), self.tool_name)

        over_thread = self.thread_limit is not None and thread_count >= self.thread_limit
        over_run = self.run_limit is not None and run_count >= self.run_limit

        if not (over_thread or over_run):
            return None

        if self.exit_behavior == "error":
            raise ToolCallLimitExceededError(
                thread_count=thread_count,
                run_count=run_count,
                thread_limit=self.thread_limit,
                run_limit=self.run_limit,
                tool_name=self.tool_name,
            )
        if self.exit_behavior == "end":
            # Surface the reason as an artificial AI message and stop the run.
            notice = AIMessage(
                content=_build_tool_limit_exceeded_message(
                    thread_count, run_count, self.thread_limit, self.run_limit, self.tool_name
                )
            )
            return {"jump_to": "end", "messages": [notice]}
        return None

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ class AgentMiddleware(Generic[StateT, ContextT]):
126126
tools: list[BaseTool]
127127
"""Additional tools registered by the middleware."""
128128

129+
@property
def name(self) -> str:
    """The name of the middleware instance.

    Defaults to the class name, but can be overridden for custom naming
    — e.g. so multiple instances of the same middleware class can coexist
    in one agent, since graph node names are derived from this value.
    """
    return self.__class__.__name__
136+
129137
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
130138
"""Logic to run before the model is called."""
131139

libs/langchain_v1/langchain/agents/middleware_agent.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def create_agent( # noqa: PLR0915
245245
) + middleware_tools
246246

247247
# validate middleware
248-
assert len({m.__class__.__name__ for m in middleware}) == len(middleware), ( # noqa: S101
248+
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
249249
"Please remove duplicate middleware instances."
250250
)
251251
middleware_w_before = [
@@ -548,9 +548,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
548548
else None
549549
)
550550
before_node = RunnableCallable(sync_before, async_before)
551-
graph.add_node(
552-
f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema
553-
)
551+
graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema)
554552

555553
if (
556554
m.__class__.after_model is not AgentMiddleware.after_model
@@ -569,20 +567,14 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
569567
else None
570568
)
571569
after_node = RunnableCallable(sync_after, async_after)
572-
graph.add_node(
573-
f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema
574-
)
570+
graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema)
575571

576572
# add start edge
577573
first_node = (
578-
f"{middleware_w_before[0].__class__.__name__}.before_model"
579-
if middleware_w_before
580-
else "model_request"
574+
f"{middleware_w_before[0].name}.before_model" if middleware_w_before else "model_request"
581575
)
582576
last_node = (
583-
f"{middleware_w_after[0].__class__.__name__}.after_model"
584-
if middleware_w_after
585-
else "model_request"
577+
f"{middleware_w_after[0].name}.after_model" if middleware_w_after else "model_request"
586578
)
587579
graph.add_edge(START, first_node)
588580

@@ -605,7 +597,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
605597
# If after_model, then need to check for can_jump_to
606598
_add_middleware_edge(
607599
graph,
608-
f"{middleware_w_after[0].__class__.__name__}.after_model",
600+
f"{middleware_w_after[0].name}.after_model",
609601
END,
610602
first_node,
611603
can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"),
@@ -616,29 +608,29 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
616608
for m1, m2 in itertools.pairwise(middleware_w_before):
617609
_add_middleware_edge(
618610
graph,
619-
f"{m1.__class__.__name__}.before_model",
620-
f"{m2.__class__.__name__}.before_model",
611+
f"{m1.name}.before_model",
612+
f"{m2.name}.before_model",
621613
first_node,
622614
can_jump_to=_get_can_jump_to(m1, "before_model"),
623615
)
624616
# Go directly to model_request after the last before_model
625617
_add_middleware_edge(
626618
graph,
627-
f"{middleware_w_before[-1].__class__.__name__}.before_model",
619+
f"{middleware_w_before[-1].name}.before_model",
628620
"model_request",
629621
first_node,
630622
can_jump_to=_get_can_jump_to(middleware_w_before[-1], "before_model"),
631623
)
632624

633625
if middleware_w_after:
634-
graph.add_edge("model_request", f"{middleware_w_after[-1].__class__.__name__}.after_model")
626+
graph.add_edge("model_request", f"{middleware_w_after[-1].name}.after_model")
635627
for idx in range(len(middleware_w_after) - 1, 0, -1):
636628
m1 = middleware_w_after[idx]
637629
m2 = middleware_w_after[idx - 1]
638630
_add_middleware_edge(
639631
graph,
640-
f"{m1.__class__.__name__}.after_model",
641-
f"{m2.__class__.__name__}.after_model",
632+
f"{m1.name}.after_model",
633+
f"{m2.name}.after_model",
642634
first_node,
643635
can_jump_to=_get_can_jump_to(m1, "after_model"),
644636
)

0 commit comments

Comments
 (0)