
Commit 3d2f13a

feat(langchain): model call limits (#33178)
This PR adds a model call limit middleware that helps manage:

* the number of model calls during a run (helps avoid tool-calling loops); implemented with `UntrackedValue`
* the number of model calls on a thread (helps avoid overly long conversations); implemented with standard state

One remaining concern is other middlewares overwriting the model call count. Should we use a `_`-prefixed field?
1 parent 99361e6 commit 3d2f13a
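
As a quick illustration of the two limits, here is a minimal usage sketch adapted from the docstring example in the new middleware file below; the model identifier and prompt are placeholders rather than part of the diff:

```python
from langchain_core.messages import HumanMessage

from langchain.agents import create_agent
from langchain.agents.middleware.call_tracking import ModelCallLimitMiddleware

# run_limit caps model calls within a single invocation (guards against
# tool-calling loops); thread_limit caps calls across all invocations on a thread.
limiter = ModelCallLimitMiddleware(thread_limit=20, run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])

# With exit_behavior="end", hitting a limit ends the run and injects an
# AIMessage such as "Model call limits exceeded: run limit (5/5)".
result = agent.invoke({"messages": [HumanMessage("Research this topic")]})
```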

File tree

4 files changed: +405 -2 lines changed
libs/langchain_v1/langchain/agents/middleware/call_tracking.py

Lines changed: 177 additions & 0 deletions

@@ -0,0 +1,177 @@
"""Call tracking middleware for agents."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

from langchain_core.messages import AIMessage

from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config

if TYPE_CHECKING:
    from langgraph.runtime import Runtime


def _build_limit_exceeded_message(
    thread_count: int,
    run_count: int,
    thread_limit: int | None,
    run_limit: int | None,
) -> str:
    """Build a message indicating which limits were exceeded.

    Args:
        thread_count: Current thread model call count.
        run_count: Current run model call count.
        thread_limit: Thread model call limit (if set).
        run_limit: Run model call limit (if set).

    Returns:
        A formatted message describing which limits were exceeded.
    """
    exceeded_limits = []
    if thread_limit is not None and thread_count >= thread_limit:
        exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
    if run_limit is not None and run_count >= run_limit:
        exceeded_limits.append(f"run limit ({run_count}/{run_limit})")

    return f"Model call limits exceeded: {', '.join(exceeded_limits)}"


class ModelCallLimitExceededError(Exception):
    """Exception raised when model call limits are exceeded.

    This exception is raised when the configured exit behavior is "error"
    and either the thread or run model call limit has been exceeded.
    """

    def __init__(
        self,
        thread_count: int,
        run_count: int,
        thread_limit: int | None,
        run_limit: int | None,
    ) -> None:
        """Initialize the exception with call count information.

        Args:
            thread_count: Current thread model call count.
            run_count: Current run model call count.
            thread_limit: Thread model call limit (if set).
            run_limit: Run model call limit (if set).
        """
        self.thread_count = thread_count
        self.run_count = run_count
        self.thread_limit = thread_limit
        self.run_limit = run_limit

        msg = _build_limit_exceeded_message(thread_count, run_count, thread_limit, run_limit)
        super().__init__(msg)


class ModelCallLimitMiddleware(AgentMiddleware):
    """Middleware that tracks model call counts and enforces limits.

    This middleware monitors the number of model calls made during agent execution
    and can terminate the agent when specified limits are reached. It supports
    both thread-level and run-level call counting with configurable exit behaviors.

    Thread-level: the middleware tracks the number of model calls and persists
    the call count across multiple runs (invocations) of the agent.

    Run-level: the middleware tracks the number of model calls made during a single
    run (invocation) of the agent.

    Example:
        ```python
        from langchain_core.messages import HumanMessage

        from langchain.agents import create_agent
        from langchain.agents.middleware.call_tracking import ModelCallLimitMiddleware

        # Create middleware with limits
        call_tracker = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")

        agent = create_agent("openai:gpt-4o", middleware=[call_tracker])

        # The agent automatically jumps to the end when limits are exceeded
        result = agent.invoke({"messages": [HumanMessage("Help me with a task")]})
        ```
    """

    def __init__(
        self,
        *,
        thread_limit: int | None = None,
        run_limit: int | None = None,
        exit_behavior: Literal["end", "error"] = "end",
    ) -> None:
        """Initialize the call tracking middleware.

        Args:
            thread_limit: Maximum number of model calls allowed per thread.
                None means no limit. Defaults to None.
            run_limit: Maximum number of model calls allowed per run.
                None means no limit. Defaults to None.
            exit_behavior: What to do when limits are exceeded.
                - "end": Jump to the end of the agent execution and
                  inject an artificial AI message indicating that the limit was exceeded.
                - "error": Raise a ModelCallLimitExceededError.
                Defaults to "end".

        Raises:
            ValueError: If both limits are None or if exit_behavior is invalid.
        """
        super().__init__()

        if thread_limit is None and run_limit is None:
            msg = "At least one limit must be specified (thread_limit or run_limit)"
            raise ValueError(msg)

        if exit_behavior not in ("end", "error"):
            msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
            raise ValueError(msg)

        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.exit_behavior = exit_behavior

    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Check model call limits before making a model call.

        Args:
            state: The current agent state containing call counts.
            runtime: The langgraph runtime.

        Returns:
            If limits are exceeded and exit_behavior is "end", returns a state
            update that jumps to the end with a limit exceeded message.
            Otherwise returns None.

        Raises:
            ModelCallLimitExceededError: If limits are exceeded and exit_behavior
                is "error".
        """
        thread_count = state.get("thread_model_call_count", 0)
        run_count = state.get("run_model_call_count", 0)

        # Check whether the next call would exceed either limit
        thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
        run_limit_exceeded = self.run_limit is not None and run_count >= self.run_limit

        if thread_limit_exceeded or run_limit_exceeded:
            if self.exit_behavior == "error":
                raise ModelCallLimitExceededError(
                    thread_count=thread_count,
                    run_count=run_count,
                    thread_limit=self.thread_limit,
                    run_limit=self.run_limit,
                )
            if self.exit_behavior == "end":
                # Create a message indicating the limit was exceeded
                limit_message = _build_limit_exceeded_message(
                    thread_count, run_count, self.thread_limit, self.run_limit
                )
                limit_ai_message = AIMessage(content=limit_message)

                return {"jump_to": "end", "messages": [limit_ai_message]}

        return None
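
With `exit_behavior="error"`, the hook raises `ModelCallLimitExceededError` instead of injecting a message. A minimal sketch of handling it, assuming the same import path as the docstring example; the model identifier and prompt are placeholders:

```python
from langchain_core.messages import HumanMessage

from langchain.agents import create_agent
from langchain.agents.middleware.call_tracking import (
    ModelCallLimitExceededError,
    ModelCallLimitMiddleware,
)

guard = ModelCallLimitMiddleware(run_limit=3, exit_behavior="error")
agent = create_agent("openai:gpt-4o", middleware=[guard])

try:
    agent.invoke({"messages": [HumanMessage("Plan my week")]})
except ModelCallLimitExceededError as exc:
    # The exception carries the counts and limits assigned in __init__
    print(f"Stopped after {exc.run_count} calls (run limit {exc.run_limit})")
```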

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@
 # needed as top level import for pydantic schema generation on AgentState
 from langchain_core.messages import AnyMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
+from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.graph.message import add_messages
 from langgraph.runtime import Runtime
 from langgraph.typing import ContextT
@@ -93,6 +94,8 @@ class AgentState(TypedDict, Generic[ResponseT]):
     messages: Required[Annotated[list[AnyMessage], add_messages]]
     jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
     structured_response: NotRequired[ResponseT]
+    thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
+    run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]


 class PublicAgentState(TypedDict, Generic[ResponseT]):
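
The two new keys differ only in their channel annotation, which is what gives them different lifetimes. A stripped-down sketch of the distinction; `CounterState` is hypothetical and omits the `PrivateStateAttr` marker used in the real `AgentState`:

```python
from typing import Annotated

from typing_extensions import NotRequired, TypedDict

from langgraph.channels.untracked_value import UntrackedValue


class CounterState(TypedDict):
    # Default channel: checkpointed, so the value persists across runs
    # on the same thread.
    thread_model_call_count: NotRequired[int]
    # UntrackedValue channel: excluded from checkpoints, so the value
    # effectively starts fresh on each run.
    run_model_call_count: NotRequired[Annotated[int, UntrackedValue]]
```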

libs/langchain_v1/langchain/agents/middleware_agent.py

Lines changed: 10 additions & 2 deletions
@@ -438,7 +438,11 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An
             messages = [SystemMessage(request.system_prompt), *messages]

         output = model_.invoke(messages)
-        return _handle_model_output(output)
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **_handle_model_output(output),
+        }

     async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
@@ -462,7 +466,11 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
             messages = [SystemMessage(request.system_prompt), *messages]

         output = await model_.ainvoke(messages)
-        return _handle_model_output(output)
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **_handle_model_output(output),
+        }

     # Use sync or async based on model capabilities
     from langgraph._internal._runnable import RunnableCallable
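
Since both counters are bumped on every model call but only the thread counter lives in a checkpointed channel, the thread count should accumulate across runs while the run count resets per invocation. A sketch of observing this, under the assumptions that `create_agent` accepts a `checkpointer` and that `InMemorySaver` exists in the installed langgraph version:

```python
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver  # assumed available

from langchain.agents import create_agent
from langchain.agents.middleware.call_tracking import ModelCallLimitMiddleware

agent = create_agent(
    "openai:gpt-4o",  # placeholder model id
    middleware=[ModelCallLimitMiddleware(thread_limit=10, run_limit=5)],
    checkpointer=InMemorySaver(),  # assumed create_agent parameter
)

config = {"configurable": {"thread_id": "demo-thread"}}
agent.invoke({"messages": [HumanMessage("First question")]}, config)
agent.invoke({"messages": [HumanMessage("Follow-up question")]}, config)

# thread_model_call_count accumulates across both runs; run_model_call_count
# lives in an UntrackedValue channel and is not carried between runs.
print(agent.get_state(config).values.get("thread_model_call_count"))
```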
