Commit 881cd7a

Add tool_calls_limit to UsageLimits and tool_calls to RunUsage (#2633)

1 parent 4918ba9 commit 881cd7a

15 files changed: +170 −21 lines changed

docs/agents.md
Lines changed: 25 additions & 2 deletions

@@ -539,7 +539,7 @@ _(This example is complete, it can be run "as is")_
 #### Usage Limits
 
 Pydantic AI offers a [`UsageLimits`][pydantic_ai.usage.UsageLimits] structure to help you limit your
-usage (tokens and/or requests) on model runs.
+usage (tokens, requests, and tool calls) on model runs.
 
 You can apply these settings by passing the `usage_limits` argument to the `run{_sync,_stream}` functions.
 
@@ -610,8 +610,31 @@ except UsageLimitExceeded as e:
 1. This tool has the ability to retry 5 times before erroring, simulating a tool that might get stuck in a loop.
 2. This run will error after 3 requests, preventing the infinite tool calling.
 
+##### Capping tool calls
+
+If you need a limit on the number of successful tool invocations within a single run, use `tool_calls_limit`:
+
+```py
+from pydantic_ai import Agent
+from pydantic_ai.exceptions import UsageLimitExceeded
+from pydantic_ai.usage import UsageLimits
+
+agent = Agent('anthropic:claude-3-5-sonnet-latest')
+
+@agent.tool_plain
+def do_work() -> str:
+    return 'ok'
+
+try:
+    # Allow at most one executed tool call in this run
+    agent.run_sync('Please call the tool twice', usage_limits=UsageLimits(tool_calls_limit=1))
+except UsageLimitExceeded as e:
+    print(e)
+    #> The next tool call would exceed the tool_calls_limit of 1 (tool_calls=1)
+```
+
 !!! note
-    Usage limits are especially relevant if you've registered many tools. The `request_limit` can be used to prevent the model from calling them in a loop too many times.
+    Usage limits are especially relevant if you've registered many tools. Use `request_limit` to bound the number of model turns, and `tool_calls_limit` to cap the number of successful tool executions within a run.
     These limits are enforced at the final stage before the LLM is called. If your limits are stricter than your retry settings, the usage limit will be reached before all retries are attempted.
 
 #### Model (Run) Settings
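
One note on the docs example above: `tool_calls_limit` composes with the existing limits rather than replacing them. A minimal sketch of bounding both model requests and tool executions in one run; the model name, prompt, and `lookup` tool here are illustrative, not from the commit:

```py
from pydantic_ai import Agent
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.usage import UsageLimits

agent = Agent('anthropic:claude-3-5-sonnet-latest')

@agent.tool_plain
def lookup(term: str) -> str:  # hypothetical tool for illustration
    return f'result for {term}'

try:
    # Bound both the number of model requests and the number of
    # successfully executed tool calls in a single run.
    agent.run_sync(
        'Research three terms',
        usage_limits=UsageLimits(request_limit=5, tool_calls_limit=2),
    )
except UsageLimitExceeded as e:
    print(e)  # whichever limit would be exceeded first raises
```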

docs/multi-agent-applications.md
Lines changed: 3 additions & 3 deletions

@@ -19,7 +19,7 @@ Since agents are stateless and designed to be global, you do not need to include
 You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.agent.AbstractAgent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run.
 
 !!! note "Multiple models"
-    Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs.
+    Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] (including `request_limit`, `total_tokens_limit`, and `tool_calls_limit`) to avoid unexpected costs or runaway tool loops.
 
 ```python {title="agent_delegation_simple.py"}
 from pydantic_ai import Agent, RunContext, UsageLimits
@@ -52,7 +52,7 @@ result = joke_selection_agent.run_sync(
 print(result.output)
 #> Did you hear about the toothpaste scandal? They called it Colgate.
 print(result.usage())
-#> RunUsage(input_tokens=204, output_tokens=24, requests=3)
+#> RunUsage(input_tokens=204, output_tokens=24, requests=3, tool_calls=1)
 ```
 
 1. The "parent" or controlling agent.
@@ -143,7 +143,7 @@ async def main():
     print(result.output)
     #> Did you hear about the toothpaste scandal? They called it Colgate.
     print(result.usage())  # (6)!
-    #> RunUsage(input_tokens=309, output_tokens=32, requests=4)
+    #> RunUsage(input_tokens=309, output_tokens=32, requests=4, tool_calls=2)
 ```
 
 1. Define a dataclass to hold the client and API key dependencies.
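
The snapshot changes above fall out of the `RunUsage` aggregation: because delegate runs are passed `usage=ctx.usage`, their successful tool executions increment the same counter as the parent's. A small sketch of that aggregation using the `incr` method this commit extends; the specific numbers are illustrative:

```py
from pydantic_ai.usage import RunUsage

parent = RunUsage(requests=2, input_tokens=100, output_tokens=20, tool_calls=1)
delegate = RunUsage(requests=1, input_tokens=104, output_tokens=4, tool_calls=1)

# Mirrors how a delegate run's usage folds into the parent's: requests and
# tool_calls are summed alongside the token counts.
parent.incr(delegate)
print(parent)
#> RunUsage(input_tokens=204, output_tokens=24, requests=3, tool_calls=2)
```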

docs/tools-advanced.md
Lines changed: 3 additions & 0 deletions

@@ -377,6 +377,9 @@ When a model returns multiple tool calls in one response, Pydantic AI schedules
 
 Async functions are run on the event loop, while sync functions are offloaded to threads. To get the best performance, _always_ use an async function _unless_ you're doing blocking I/O (and there's no way to use a non-blocking library instead) or CPU-bound work (like `numpy` or `scikit-learn` operations), so that simple functions are not offloaded to threads unnecessarily.
 
+!!! note "Limiting tool executions"
+    You can cap tool executions within a run using [`UsageLimits(tool_calls_limit=...)`](agents.md#usage-limits). The counter increments only after a successful tool invocation. Output tools (used for structured output) are not counted in the `tool_calls` metric.
+
 ## See Also
 
 - [Function Tools](tools.md) - Basic tool concepts and registration
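
To observe the counter without a real provider, here is a hedged sketch using `TestModel`; it assumes `TestModel`'s default behavior of calling each registered tool once before producing a response:

```py
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())

@agent.tool_plain
def ping() -> str:
    return 'pong'

result = agent.run_sync('go')
# Only the successful function-tool execution is counted; an output tool
# used for structured output would not increment this metric.
print(result.usage().tool_calls)
#> 1
```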

pydantic_ai_slim/pydantic_ai/_agent_graph.py
Lines changed: 6 additions & 3 deletions

@@ -756,6 +756,7 @@ async def process_function_tools(  # noqa: C901
         calls_to_run,
         deferred_tool_results,
         ctx.deps.tracer,
+        ctx.deps.usage_limits,
         output_parts,
         deferred_calls,
     ):
@@ -802,6 +803,7 @@ async def _call_tools(
    tool_calls: list[_messages.ToolCallPart],
    deferred_tool_results: dict[str, DeferredToolResult],
    tracer: Tracer,
+    usage_limits: _usage.UsageLimits | None,
    output_parts: list[_messages.ModelRequestPart],
    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -822,7 +824,7 @@ async def _call_tools(
     ):
         tasks = [
             asyncio.create_task(
-                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id)),
+                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
                 name=call.tool_name,
             )
             for call in tool_calls
@@ -862,14 +864,15 @@ async def _call_tool(
    tool_manager: ToolManager[DepsT],
    tool_call: _messages.ToolCallPart,
    tool_call_result: DeferredToolResult | None,
+    usage_limits: _usage.UsageLimits | None,
 ) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
     try:
         if tool_call_result is None:
-            tool_result = await tool_manager.handle_call(tool_call)
+            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
         elif isinstance(tool_call_result, ToolApproved):
             if tool_call_result.override_args is not None:
                 tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
-            tool_result = await tool_manager.handle_call(tool_call)
+            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
         elif isinstance(tool_call_result, ToolDenied):
             return _messages.ToolReturnPart(
                 tool_name=tool_call.tool_name,
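
For reviewers skimming the plumbing: `usage_limits` is threaded as an explicit parameter from `process_function_tools` down to the leaf `_call_tool`, so the normal and deferred-approval paths enforce the same limit. An illustrative sketch of the pattern, with simplified names that are not the library's internals:

```py
import asyncio
from dataclasses import dataclass

@dataclass
class Limits:
    tool_calls_limit: int | None = None

async def call_tool(name: str, limits: Limits | None) -> str:
    # Only the leaf inspects the limits; intermediate layers pass them through.
    if limits is not None and limits.tool_calls_limit == 0:
        raise RuntimeError('tool call limit exhausted')
    return f'{name}: ok'

async def call_tools(names: list[str], limits: Limits | None) -> list[str]:
    # Mirrors the fan-out in _call_tools: one named task per tool call,
    # each receiving the same limits object.
    tasks = [asyncio.create_task(call_tool(n, limits), name=n) for n in names]
    return list(await asyncio.gather(*tasks))

print(asyncio.run(call_tools(['a', 'b'], Limits(tool_calls_limit=3))))
#> ['a: ok', 'b: ok']
```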

pydantic_ai_slim/pydantic_ai/_tool_manager.py
Lines changed: 29 additions & 6 deletions

@@ -14,6 +14,7 @@
 from .messages import ToolCallPart
 from .tools import ToolDefinition
 from .toolsets.abstract import AbstractToolset, ToolsetTool
+from .usage import UsageLimits
 
 
 @dataclass
@@ -66,31 +67,44 @@ def get_tool_def(self, name: str) -> ToolDefinition | None:
         return None
 
     async def handle_call(
-        self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
+        self,
+        call: ToolCallPart,
+        allow_partial: bool = False,
+        wrap_validation_errors: bool = True,
+        usage_limits: UsageLimits | None = None,
     ) -> Any:
         """Handle a tool call by validating the arguments, calling the tool, and handling retries.
 
         Args:
             call: The tool call part to handle.
             allow_partial: Whether to allow partial validation of the tool arguments.
             wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
+            usage_limits: Optional usage limits to check before executing tools.
         """
         if self.tools is None or self.ctx is None:
             raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
 
         if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
-            # Output tool calls are not traced
-            return await self._call_tool(call, allow_partial, wrap_validation_errors)
+            # Output tool calls are not traced and not counted
+            return await self._call_tool(call, allow_partial, wrap_validation_errors, count_tool_usage=False)
         else:
             return await self._call_tool_traced(
                 call,
                 allow_partial,
                 wrap_validation_errors,
                 self.ctx.tracer,
                 self.ctx.trace_include_content,
+                usage_limits,
             )
 
-    async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validation_errors: bool) -> Any:
+    async def _call_tool(
+        self,
+        call: ToolCallPart,
+        allow_partial: bool,
+        wrap_validation_errors: bool,
+        usage_limits: UsageLimits | None = None,
+        count_tool_usage: bool = True,
+    ) -> Any:
         if self.tools is None or self.ctx is None:
             raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
 
@@ -121,7 +135,15 @@ async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validat
             else:
                 args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
 
-            return await self.toolset.call_tool(name, args_dict, ctx, tool)
+            if usage_limits is not None and count_tool_usage:
+                usage_limits.check_before_tool_call(self.ctx.usage)
+
+            result = await self.toolset.call_tool(name, args_dict, ctx, tool)
+
+            if count_tool_usage:
+                self.ctx.usage.tool_calls += 1
+
+            return result
         except (ValidationError, ModelRetry) as e:
             max_retries = tool.max_retries if tool is not None else 1
             current_retry = self.ctx.retries.get(name, 0)
@@ -160,6 +182,7 @@ async def _call_tool_traced(
         wrap_validation_errors: bool,
         tracer: Tracer,
         include_content: bool = False,
+        usage_limits: UsageLimits | None = None,
     ) -> Any:
         """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
         span_attributes = {
@@ -189,7 +212,7 @@ async def _call_tool_traced(
         }
         with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
             try:
-                tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
+                tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits)
             except ToolRetryError as e:
                 part = e.tool_retry
                 if include_content and span.is_recording():
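
The behavioral core of this file is the check-then-increment bracket around the awaited call: the limit is checked before the tool runs, and the counter advances only after success, so validation errors and `ModelRetry` exceptions don't consume budget. A stripped-down sketch of that pattern, with illustrative names rather than the library's internals:

```py
import asyncio

class LimitExceeded(Exception):
    pass

async def call_with_budget(tool, usage: dict, limit: int | None):
    # 1. Refuse before running if the next call would exceed the limit.
    if limit is not None and usage['tool_calls'] >= limit:
        raise LimitExceeded(f'tool_calls_limit of {limit} reached')
    # 2. Run the tool; if it raises, the counter is left untouched,
    #    so failed calls and retries don't consume budget.
    result = await tool()
    # 3. Count only after success.
    usage['tool_calls'] += 1
    return result

async def main():
    usage = {'tool_calls': 0}

    async def ok():
        return 'ok'

    print(await call_with_budget(ok, usage, limit=1))  # ok
    try:
        await call_with_budget(ok, usage, limit=1)
    except LimitExceeded as e:
        print(e)  # tool_calls_limit of 1 reached

asyncio.run(main())
```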

pydantic_ai_slim/pydantic_ai/usage.py
Lines changed: 20 additions & 0 deletions

@@ -117,6 +117,9 @@ class RunUsage(UsageBase):
     requests: int = 0
     """Number of requests made to the LLM API."""
 
+    tool_calls: int = 0
+    """Number of successful tool calls executed during the run."""
+
     input_tokens: int = 0
     """Total number of text input/prompt tokens."""
 
@@ -146,6 +149,7 @@ def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
         """
         if isinstance(incr_usage, RunUsage):
             self.requests += incr_usage.requests
+            self.tool_calls += incr_usage.tool_calls
         return _incr_usage_tokens(self, incr_usage)
 
     def __add__(self, other: RunUsage | RequestUsage) -> RunUsage:
@@ -194,6 +198,8 @@ class UsageLimits:
 
     request_limit: int | None = 50
     """The maximum number of requests allowed to the model."""
+    tool_calls_limit: int | None = None
+    """The maximum number of successful tool calls allowed to be executed."""
     input_tokens_limit: int | None = None
     """The maximum number of input/prompt tokens allowed."""
     output_tokens_limit: int | None = None
@@ -220,12 +226,14 @@ def __init__(
         self,
         *,
         request_limit: int | None = 50,
+        tool_calls_limit: int | None = None,
         input_tokens_limit: int | None = None,
         output_tokens_limit: int | None = None,
         total_tokens_limit: int | None = None,
         count_tokens_before_request: bool = False,
     ) -> None:
         self.request_limit = request_limit
+        self.tool_calls_limit = tool_calls_limit
         self.input_tokens_limit = input_tokens_limit
         self.output_tokens_limit = output_tokens_limit
         self.total_tokens_limit = total_tokens_limit
@@ -239,12 +247,14 @@ def __init__(
         self,
         *,
         request_limit: int | None = 50,
+        tool_calls_limit: int | None = None,
         request_tokens_limit: int | None = None,
         response_tokens_limit: int | None = None,
         total_tokens_limit: int | None = None,
         count_tokens_before_request: bool = False,
     ) -> None:
         self.request_limit = request_limit
+        self.tool_calls_limit = tool_calls_limit
         self.input_tokens_limit = request_tokens_limit
         self.output_tokens_limit = response_tokens_limit
         self.total_tokens_limit = total_tokens_limit
@@ -254,6 +264,7 @@ def __init__(
         self,
         *,
         request_limit: int | None = 50,
+        tool_calls_limit: int | None = None,
         input_tokens_limit: int | None = None,
         output_tokens_limit: int | None = None,
         total_tokens_limit: int | None = None,
@@ -263,6 +274,7 @@ def __init__(
         response_tokens_limit: int | None = None,
     ):
         self.request_limit = request_limit
+        self.tool_calls_limit = tool_calls_limit
         self.input_tokens_limit = input_tokens_limit or request_tokens_limit
         self.output_tokens_limit = output_tokens_limit or response_tokens_limit
         self.total_tokens_limit = total_tokens_limit
@@ -314,4 +326,12 @@ def check_tokens(self, usage: RunUsage) -> None:
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
 
+    def check_before_tool_call(self, usage: RunUsage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next tool call would exceed the tool call limit."""
+        tool_calls_limit = self.tool_calls_limit
+        if tool_calls_limit is not None and usage.tool_calls >= tool_calls_limit:
+            raise UsageLimitExceeded(
+                f'The next tool call would exceed the tool_calls_limit of {tool_calls_limit} (tool_calls={usage.tool_calls})'
+            )
+
     __repr__ = _utils.dataclasses_no_defaults_repr
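
The new guard can be exercised directly against the types changed in this file; the expected message below mirrors the f-string defined above:

```py
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.usage import RunUsage, UsageLimits

limits = UsageLimits(tool_calls_limit=2)

# Below the limit: no exception is raised.
limits.check_before_tool_call(RunUsage(tool_calls=1))

# At the limit: the next call would exceed it, so the guard raises.
try:
    limits.check_before_tool_call(RunUsage(tool_calls=2))
except UsageLimitExceeded as e:
    print(e)
    #> The next tool call would exceed the tool_calls_limit of 2 (tool_calls=2)
```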

tests/models/test_anthropic.py
Lines changed: 1 addition & 0 deletions

@@ -667,6 +667,7 @@ async def my_tool(first: str, second: str) -> int:
             requests=2,
             input_tokens=20,
             output_tokens=5,
+            tool_calls=1,
             details={'input_tokens': 20, 'output_tokens': 5},
         )
     )

tests/models/test_bedrock.py
Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ async def temperature(city: str, date: datetime.date) -> str:
 
     result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response)
     assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'})
-    assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=1236, output_tokens=298))
+    assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=1236, output_tokens=298, tool_calls=1))
     assert result.all_messages() == snapshot(
         [
             ModelRequest(

tests/models/test_cohere.py
Lines changed: 1 addition & 0 deletions

@@ -330,6 +330,7 @@ async def get_location(loc_name: str) -> str:
             input_tokens=5,
             output_tokens=3,
             details={'input_tokens': 4, 'output_tokens': 2},
+            tool_calls=1,
         )
     )

tests/models/test_gemini.py
Lines changed: 2 additions & 2 deletions

@@ -783,7 +783,7 @@ async def get_location(loc_name: str) -> str:
             ),
         ]
     )
-    assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6))
+    assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6, tool_calls=2))
 
 
 async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None):
@@ -932,7 +932,7 @@ async def bar(y: str) -> str:
     async with agent.run_stream('Hello') as result:
         response = await result.get_output()
         assert response == snapshot((1, 2))
-        assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=2, output_tokens=4))
+        assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=2, output_tokens=4, tool_calls=2))
     assert result.all_messages() == snapshot(
         [
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
