Commit 3b8ff2c

Fix parallel tool call limit enforcement (#2978)

Parent: 3db21d4

6 files changed: +137 -34 lines changed


docs/agents.md (2 additions & 2 deletions)

@@ -630,12 +630,12 @@ try:
     agent.run_sync('Please call the tool twice', usage_limits=UsageLimits(tool_calls_limit=1))
 except UsageLimitExceeded as e:
     print(e)
-    #> The next tool call would exceed the tool_calls_limit of 1 (tool_calls=1)
+    #> The next tool call(s) would exceed the tool_calls_limit of 1 (tool_calls=2).
 ```

 !!! note
     - Usage limits are especially relevant if you've registered many tools. Use `request_limit` to bound the number of model turns, and `tool_calls_limit` to cap the number of successful tool executions within a run.
-    - These limits are enforced at the final stage before the LLM is called. If your limits are stricter than your retry settings, the usage limit will be reached before all retries are attempted.
+    - The `tool_calls_limit` is checked before executing tool calls. If the model returns parallel tool calls that would exceed the limit, no tools will be executed.

 #### Model (Run) Settings
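
For context, here is the documented scenario above as a runnable end-to-end sketch of the fixed behavior. It is illustrative only: the stub model function and the `do_work` tool are made-up names, and the imports assume pydantic-ai's public API as exercised by the tests in this commit (`Agent`, `FunctionModel`, `UsageLimits`, `UsageLimitExceeded`).

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.usage import UsageLimits


def two_parallel_calls(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # Stub model: always requests two parallel `do_work` calls.
    return ModelResponse(
        parts=[
            ToolCallPart('do_work', {}, 'call_1'),
            ToolCallPart('do_work', {}, 'call_2'),
        ]
    )


agent = Agent(FunctionModel(two_parallel_calls))


@agent.tool_plain
def do_work() -> str:
    return 'done'


try:
    agent.run_sync('Please call the tool twice', usage_limits=UsageLimits(tool_calls_limit=1))
except UsageLimitExceeded as e:
    # Both pending calls are counted up front, so neither tool executes.
    print(e)
    #> The next tool call(s) would exceed the tool_calls_limit of 1 (tool_calls=2).
```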
pydantic_ai_slim/pydantic_ai/_agent_graph.py (12 additions & 7 deletions)

@@ -230,7 +230,6 @@ async def run( # noqa: C901
         # Build the run context after `ctx.deps.prompt` has been updated
         run_context = build_run_context(ctx)

-        parts: list[_messages.ModelRequestPart] = []
         if messages:
             await self._reevaluate_dynamic_prompts(messages, run_context)

@@ -840,6 +839,7 @@ async def process_tool_calls( # noqa: C901
         tool_calls=calls_to_run,
         tool_call_results=calls_to_run_results,
         tracer=ctx.deps.tracer,
+        usage=ctx.state.usage,
         usage_limits=ctx.deps.usage_limits,
         output_parts=output_parts,
         output_deferred_calls=deferred_calls,
@@ -886,14 +886,20 @@ async def _call_tools(
     tool_calls: list[_messages.ToolCallPart],
     tool_call_results: dict[str, DeferredToolResult],
     tracer: Tracer,
-    usage_limits: _usage.UsageLimits | None,
+    usage: _usage.RunUsage,
+    usage_limits: _usage.UsageLimits,
     output_parts: list[_messages.ModelRequestPart],
     output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
     tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
     user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
     deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}

+    if usage_limits.tool_calls_limit is not None:
+        projected_usage = deepcopy(usage)
+        projected_usage.tool_calls += len(tool_calls)
+        usage_limits.check_before_tool_call(projected_usage)
+
     for call in tool_calls:
         yield _messages.FunctionToolCallEvent(call)

@@ -930,15 +936,15 @@ async def handle_call_or_result(
     if tool_manager.should_call_sequentially(tool_calls):
         for index, call in enumerate(tool_calls):
             if event := await handle_call_or_result(
-                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
                 index,
             ):
                 yield event

     else:
         tasks = [
             asyncio.create_task(
-                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
                 name=call.tool_name,
             )
             for call in tool_calls
@@ -965,15 +971,14 @@ async def _call_tool(
     tool_manager: ToolManager[DepsT],
     tool_call: _messages.ToolCallPart,
     tool_call_result: DeferredToolResult | None,
-    usage_limits: _usage.UsageLimits | None,
 ) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
     try:
         if tool_call_result is None:
-            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+            tool_result = await tool_manager.handle_call(tool_call)
         elif isinstance(tool_call_result, ToolApproved):
             if tool_call_result.override_args is not None:
                 tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
-            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+            tool_result = await tool_manager.handle_call(tool_call)
         elif isinstance(tool_call_result, ToolDenied):
             return _messages.ToolReturnPart(
                 tool_name=tool_call.tool_name,
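
The heart of the fix is the all-or-nothing pre-check added to `_call_tools` above: the run usage is copied, projected forward by the size of the entire parallel batch, and validated before any tool executes. A standalone sketch of that pattern follows; the `Usage`, `Limits`, and `run_batch` names are illustrative stand-ins, not the library's `RunUsage`/`UsageLimits`.

```python
from copy import deepcopy
from dataclasses import dataclass


@dataclass
class Usage:
    tool_calls: int = 0


@dataclass
class Limits:
    tool_calls_limit: int | None = None

    def check_before_tool_call(self, projected: Usage) -> None:
        # `projected` already includes the pending batch, hence the strict `>`.
        if self.tool_calls_limit is not None and projected.tool_calls > self.tool_calls_limit:
            raise RuntimeError(
                f'The next tool call(s) would exceed the tool_calls_limit of '
                f'{self.tool_calls_limit} (tool_calls={projected.tool_calls}).'
            )


def run_batch(usage: Usage, limits: Limits, batch: list[str]) -> list[str]:
    # Project usage as if the whole parallel batch had already run, and check it
    # *before* executing anything: the batch runs in full or not at all.
    if limits.tool_calls_limit is not None:
        projected = deepcopy(usage)
        projected.tool_calls += len(batch)
        limits.check_before_tool_call(projected)

    results = []
    for name in batch:
        usage.tool_calls += 1  # counted only when a call actually runs
        results.append(f'{name} ok')
    return results
```

With `Usage(tool_calls=5)` and `Limits(tool_calls_limit=6)`, a three-call batch raises before anything runs and the counter stays at 5, which is the situation exercised by the new parallel-limit test below.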

pydantic_ai_slim/pydantic_ai/_tool_manager.py (9 additions & 16 deletions)

@@ -18,7 +18,7 @@
 from .messages import ToolCallPart
 from .tools import ToolDefinition
 from .toolsets.abstract import AbstractToolset, ToolsetTool
-from .usage import UsageLimits
+from .usage import RunUsage

 _sequential_tool_calls_ctx_var: ContextVar[bool] = ContextVar('sequential_tool_calls', default=False)

@@ -93,7 +93,6 @@ async def handle_call(
         call: ToolCallPart,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
-        usage_limits: UsageLimits | None = None,
     ) -> Any:
         """Handle a tool call by validating the arguments, calling the tool, and handling retries.

@@ -108,25 +107,23 @@

         if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
             # Output tool calls are not traced and not counted
-            return await self._call_tool(call, allow_partial, wrap_validation_errors, count_tool_usage=False)
+            return await self._call_tool(call, allow_partial, wrap_validation_errors)
         else:
-            return await self._call_tool_traced(
+            return await self._call_function_tool(
                 call,
                 allow_partial,
                 wrap_validation_errors,
                 self.ctx.tracer,
                 self.ctx.trace_include_content,
                 self.ctx.instrumentation_version,
-                usage_limits,
+                self.ctx.usage,
             )

     async def _call_tool(
         self,
         call: ToolCallPart,
         allow_partial: bool,
         wrap_validation_errors: bool,
-        usage_limits: UsageLimits | None = None,
-        count_tool_usage: bool = True,
     ) -> Any:
         if self.tools is None or self.ctx is None:
             raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
@@ -159,14 +156,8 @@ async def _call_tool(
             else:
                 args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)

-            if usage_limits is not None and count_tool_usage:
-                usage_limits.check_before_tool_call(self.ctx.usage)
-
             result = await self.toolset.call_tool(name, args_dict, ctx, tool)

-            if count_tool_usage:
-                self.ctx.usage.tool_calls += 1
-
             return result
         except (ValidationError, ModelRetry) as e:
             max_retries = tool.max_retries if tool is not None else 1
@@ -199,15 +190,15 @@ async def _call_tool(

             raise e

-    async def _call_tool_traced(
+    async def _call_function_tool(
         self,
         call: ToolCallPart,
         allow_partial: bool,
         wrap_validation_errors: bool,
         tracer: Tracer,
         include_content: bool,
         instrumentation_version: int,
-        usage_limits: UsageLimits | None = None,
+        usage: RunUsage,
     ) -> Any:
         """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
         instrumentation_names = InstrumentationNames.for_version(instrumentation_version)
@@ -242,7 +233,9 @@ async def _call_tool_traced(
             attributes=span_attributes,
         ) as span:
             try:
-                tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits)
+                tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
+                usage.tool_calls += 1
+
             except ToolRetryError as e:
                 part = e.tool_retry
                 if include_content and span.is_recording():
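
With the counter removed from `_call_tool`, the renamed `_call_function_tool` is now the only place that increments `tool_calls`, and it does so only after the wrapped call returns, so retried or failed calls and output-tool calls are never counted. A rough, self-contained sketch of that rule (the `Usage` class, `ToolRetryError` stand-in, and helper names here are hypothetical scaffolding, not the real tool-manager types):

```python
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass


class ToolRetryError(Exception):
    """Stand-in for the library's retry signal."""


@dataclass
class Usage:
    tool_calls: int = 0


async def call_function_tool(tool: Callable[[], Awaitable[str]], usage: Usage) -> str:
    # Mirrors the change in `_call_function_tool`: the counter is bumped only
    # after the underlying call has succeeded, so failures consume no budget.
    result = await tool()
    usage.tool_calls += 1
    return result


async def demo() -> None:
    usage = Usage()

    async def flaky() -> str:
        raise ToolRetryError('Temporary failure, please retry')

    async def ok() -> str:
        return 'done'

    try:
        await call_function_tool(flaky, usage)
    except ToolRetryError:
        pass  # the failed attempt is not counted

    await call_function_tool(ok, usage)
    assert usage.tool_calls == 1  # only the successful call counted


asyncio.run(demo())
```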

pydantic_ai_slim/pydantic_ai/usage.py (5 additions & 4 deletions)

@@ -340,12 +340,13 @@ def check_tokens(self, usage: RunUsage) -> None:
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')

-    def check_before_tool_call(self, usage: RunUsage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the next tool call would exceed the tool call limit."""
+    def check_before_tool_call(self, projected_usage: RunUsage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next tool call(s) would exceed the tool call limit."""
         tool_calls_limit = self.tool_calls_limit
-        if tool_calls_limit is not None and usage.tool_calls >= tool_calls_limit:
+        tool_calls = projected_usage.tool_calls
+        if tool_calls_limit is not None and tool_calls > tool_calls_limit:
             raise UsageLimitExceeded(
-                f'The next tool call would exceed the tool_calls_limit of {tool_calls_limit} (tool_calls={usage.tool_calls})'
+                f'The next tool call(s) would exceed the tool_calls_limit of {tool_calls_limit} ({tool_calls=}).'
             )

     __repr__ = _utils.dataclasses_no_defaults_repr
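
Note the boundary shift: the old check compared the current count against the limit with `>=` before each individual call, whereas the new check receives usage already projected over the pending batch and uses a strict `>`, so a batch that lands exactly on the limit is still allowed. A tiny illustration with a hypothetical helper (not library code):

```python
def would_exceed(tool_calls_so_far: int, pending_calls: int, tool_calls_limit: int) -> bool:
    # New contract: project usage over the whole pending batch, then compare strictly.
    projected = tool_calls_so_far + pending_calls
    return projected > tool_calls_limit


assert would_exceed(0, 2, 2) is False  # lands exactly on the limit: allowed
assert would_exceed(1, 2, 2) is True   # would overshoot: the real check raises
assert would_exceed(0, 2, 1) is True   # the docs example: limit 1, two parallel calls
```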

tests/test_examples.py (4 additions & 1 deletion)

@@ -393,7 +393,10 @@ async def call_tool(
         'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries.'
         'Rome is known for its rich history, stunning architecture, and delicious cuisine.'
     ),
-    'Please call the tool twice': ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id'),
+    'Please call the tool twice': [
+        ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_1'),
+        ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_2'),
+    ],
     'Begin infinite retry loop!': ToolCallPart(
         tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id'
     ),

tests/test_usage_limits.py (105 additions & 4 deletions)

@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import operator
 import re
@@ -12,6 +13,7 @@

 from pydantic_ai import (
     Agent,
+    ModelMessage,
     ModelRequest,
     ModelResponse,
     RunContext,
@@ -21,6 +23,7 @@
     UserPromptPart,
 )
 from pydantic_ai.exceptions import ModelRetry
+from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import ToolOutput
 from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
@@ -253,7 +256,8 @@ async def ret_a(x: str) -> str:
         return f'{x}-apple'

     with pytest.raises(
-        UsageLimitExceeded, match=re.escape('The next tool call would exceed the tool_calls_limit of 0 (tool_calls=0)')
+        UsageLimitExceeded,
+        match=re.escape('The next tool call(s) would exceed the tool_calls_limit of 0 (tool_calls=1).'),
     ):
         await test_agent.run('Hello', usage_limits=UsageLimits(tool_calls_limit=0))

@@ -286,8 +290,42 @@ async def another_regular_tool(x: str) -> str:
     assert result_output.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=15, tool_calls=1))


+async def test_output_tool_allowed_at_limit() -> None:
+    """Test that output tools can be called even when at the tool_calls_limit."""
+
+    class MyOutput(BaseModel):
+        result: str
+
+    def call_output_after_regular(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if len(messages) == 1:
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('regular_tool', {'x': 'test'}, 'call_1'),
+                ],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        else:
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('final_result', {'result': 'success'}, 'call_2'),
+                ],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+
+    test_agent = Agent(FunctionModel(call_output_after_regular), output_type=ToolOutput(MyOutput))
+
+    @test_agent.tool_plain
+    async def regular_tool(x: str) -> str:
+        return f'{x}-processed'
+
+    result = await test_agent.run('test', usage_limits=UsageLimits(tool_calls_limit=1))
+
+    assert result.output.result == 'success'
+    assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=20, output_tokens=10, tool_calls=1))
+
+
 async def test_failed_tool_calls_not_counted() -> None:
-    """Test that failed tool calls (raising ModelRetry) are not counted."""
+    """Test that failed tool calls (raising ModelRetry) are not counted in usage or against limits."""
     test_agent = Agent(TestModel())

     call_count = 0
@@ -300,8 +338,7 @@ async def flaky_tool(x: str) -> str:
             raise ModelRetry('Temporary failure, please retry')
         return f'{x}-success'

-    result = await test_agent.run('test')
-    # The tool was called twice (1 failure + 1 success), but only the successful call should be counted
+    result = await test_agent.run('test', usage_limits=UsageLimits(tool_calls_limit=1))
     assert call_count == 2
     assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=176, output_tokens=29, tool_calls=1))

@@ -316,3 +353,67 @@ def test_deprecated_usage_limits():
         snapshot(['DeprecationWarning: `response_tokens_limit` is deprecated, use `output_tokens_limit` instead'])
     ):
         assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100  # type: ignore
+
+
+async def test_parallel_tool_calls_limit_enforced():
+    """Parallel tool calls must not exceed the limit and should raise immediately."""
+    executed_tools: list[str] = []
+
+    model_call_count = 0
+
+    def test_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        nonlocal model_call_count
+        model_call_count += 1
+
+        if model_call_count == 1:
+            # First response: 5 parallel tool calls (within limit)
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('tool_a', {}, 'call_1'),
+                    ToolCallPart('tool_b', {}, 'call_2'),
+                    ToolCallPart('tool_c', {}, 'call_3'),
+                    ToolCallPart('tool_a', {}, 'call_4'),
+                    ToolCallPart('tool_b', {}, 'call_5'),
+                ]
+            )
+        else:
+            assert model_call_count == 2
+            # Second response: 3 parallel tool calls (would exceed limit of 6)
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('tool_c', {}, 'call_6'),
+                    ToolCallPart('tool_a', {}, 'call_7'),
+                    ToolCallPart('tool_b', {}, 'call_8'),
+                ]
+            )
+
+    test_model = FunctionModel(test_model_function)
+    agent = Agent(test_model)
+
+    @agent.tool_plain
+    async def tool_a() -> str:
+        await asyncio.sleep(0.01)
+        executed_tools.append('a')
+        return 'result a'
+
+    @agent.tool_plain
+    async def tool_b() -> str:
+        await asyncio.sleep(0.01)
+        executed_tools.append('b')
+        return 'result b'
+
+    @agent.tool_plain
+    async def tool_c() -> str:
+        await asyncio.sleep(0.01)
+        executed_tools.append('c')
+        return 'result c'
+
+    # Run with tool call limit of 6; expecting an error when trying to execute 3 more tools
+    with pytest.raises(
+        UsageLimitExceeded,
+        match=re.escape('The next tool call(s) would exceed the tool_calls_limit of 6 (tool_calls=8).'),
+    ):
+        await agent.run('Use tools', usage_limits=UsageLimits(tool_calls_limit=6))

+    # Only the first batch of 5 tools should have executed
+    assert len(executed_tools) == 5
