Skip to content

Commit a235ee7

Browse files
committed
fix: enforce tool call limit for parallel tool calls
1 parent 79ef2bf commit a235ee7

File tree

2 files changed

+107
-3
lines changed

2 files changed

+107
-3
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,29 @@ async def process_tool_calls( # noqa: C901
860860
output_final_result.append(final_result)
861861

862862

863+
def _enforce_tool_call_limits(
    tool_manager: ToolManager[DepsT],
    tool_calls: list[_messages.ToolCallPart],
    usage_limits: _usage.UsageLimits | None,
) -> tuple[list[_messages.ToolCallPart], int]:
    """Trim a batch of tool calls so it fits within the remaining tool-call budget.

    Args:
        tool_manager: Manager whose `ctx.usage.tool_calls` is read as the number of tool calls
            already counted; a missing context is treated as zero calls so far.
        tool_calls: The batch of requested tool calls, in request order.
        usage_limits: Limits for the run. `None`, or a `tool_calls_limit` of `None`, disables
            enforcement.

    Returns:
        A `(allowed_calls, extra_calls_count)` tuple: the leading calls that still fit under
        `tool_calls_limit`, and how many trailing calls were cut off. When nothing is cut off the
        original `tool_calls` list object is returned unchanged.

    Raises:
        Whatever `usage_limits.check_before_tool_call` raises when the budget is already used up
        and the batch is non-empty.
    """
    # An empty batch has no "next tool call", so it must never trip the limit check below.
    if not tool_calls or usage_limits is None or usage_limits.tool_calls_limit is None:
        return tool_calls, 0

    # Read the context once so both the count and the limit check see the same object.
    ctx = tool_manager.ctx
    current_tool_calls = ctx.usage.tool_calls if ctx is not None else 0
    remaining_allowed = usage_limits.tool_calls_limit - current_tool_calls

    if remaining_allowed <= 0:
        # Delegate to the limits object so the error raised matches the per-call check.
        # Without a context, a fresh usage object is passed (presumably all-zero counters).
        usage_limits.check_before_tool_call(ctx.usage if ctx is not None else _usage.RunUsage())

    if remaining_allowed >= len(tool_calls):
        return tool_calls, 0

    # `max(0, ...)` guards against a negative slice bound if the check above did not raise.
    limited_tool_calls = tool_calls[: max(0, remaining_allowed)]
    return limited_tool_calls, len(tool_calls) - len(limited_tool_calls)
884+
885+
863886
async def _call_tools(
864887
tool_manager: ToolManager[DepsT],
865888
tool_calls: list[_messages.ToolCallPart],
@@ -906,6 +929,8 @@ async def handle_call_or_result(
906929

907930
return _messages.FunctionToolResultEvent(tool_part)
908931

932+
executed_calls: list[_messages.ToolCallPart] = tool_calls
933+
909934
if tool_manager.should_call_sequentially(tool_calls):
910935
for index, call in enumerate(tool_calls):
911936
if event := await handle_call_or_result(
@@ -915,12 +940,14 @@ async def handle_call_or_result(
915940
yield event
916941

917942
else:
943+
executed_calls, extra_calls_count = _enforce_tool_call_limits(tool_manager, tool_calls, usage_limits)
944+
918945
tasks = [
919946
asyncio.create_task(
920947
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
921948
name=call.tool_name,
922949
)
923-
for call in tool_calls
950+
for call in executed_calls
924951
]
925952

926953
pending = tasks
@@ -931,13 +958,17 @@ async def handle_call_or_result(
931958
if event := await handle_call_or_result(coro_or_task=task, index=index):
932959
yield event
933960

961+
# If there were extra calls beyond the allowed limit, raise now
962+
if extra_calls_count and usage_limits is not None:
963+
usage_limits.check_before_tool_call(tool_manager.ctx.usage if tool_manager.ctx else _usage.RunUsage())
964+
934965
# We append the results at the end, rather than as they are received, to retain a consistent ordering
935966
# This is mostly just to simplify testing
936967
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
937968
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
938969

939970
for k in sorted(deferred_calls_by_index):
940-
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
971+
output_deferred_calls[deferred_calls_by_index[k]].append(executed_calls[k])
941972

942973

943974
async def _call_tool(

tests/test_usage_limits.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import functools
23
import operator
34
import re
@@ -12,7 +13,15 @@
1213

1314
from pydantic_ai import Agent, RunContext, UsageLimitExceeded
1415
from pydantic_ai.exceptions import ModelRetry
15-
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart, UserPromptPart
16+
from pydantic_ai.messages import (
17+
ModelMessage,
18+
ModelRequest,
19+
ModelResponse,
20+
ToolCallPart,
21+
ToolReturnPart,
22+
UserPromptPart,
23+
)
24+
from pydantic_ai.models.function import AgentInfo, FunctionModel
1625
from pydantic_ai.models.test import TestModel
1726
from pydantic_ai.output import ToolOutput
1827
from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
@@ -308,3 +317,67 @@ def test_deprecated_usage_limits():
308317
snapshot(['DeprecationWarning: `response_tokens_limit` is deprecated, use `output_tokens_limit` instead'])
309318
):
310319
assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore
320+
321+
322+
async def test_parallel_tool_calls_limit_enforced():
    """A parallel batch that would overshoot `tool_calls_limit` runs only what fits, then raises."""
    ran: list[str] = []
    model_call_count = 0

    # (tool name, tool call id) pairs for each of the two model responses.
    first_batch = [
        ('tool_a', 'call_1'),
        ('tool_b', 'call_2'),
        ('tool_c', 'call_3'),
        ('tool_a', 'call_4'),
        ('tool_b', 'call_5'),
    ]
    second_batch = [
        ('tool_c', 'call_6'),
        ('tool_a', 'call_7'),
        ('tool_b', 'call_8'),
    ]

    def test_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        nonlocal model_call_count
        model_call_count += 1

        if model_call_count != 1:
            # The run is expected to fail during the second batch, so a third request is a bug.
            assert model_call_count == 2

        # Response 1 asks for 5 parallel calls; response 2 asks for 3 more (overshooting 6).
        batch = first_batch if model_call_count == 1 else second_batch
        return ModelResponse(parts=[ToolCallPart(name, {}, call_id) for name, call_id in batch])

    agent = Agent(FunctionModel(test_model_function))

    @agent.tool_plain
    async def tool_a() -> str:
        await asyncio.sleep(0.01)
        ran.append('a')
        return 'result a'

    @agent.tool_plain
    async def tool_b() -> str:
        await asyncio.sleep(0.01)
        ran.append('b')
        return 'result b'

    @agent.tool_plain
    async def tool_c() -> str:
        await asyncio.sleep(0.01)
        ran.append('c')
        return 'result c'

    # With a budget of 6 tool calls, the 8 requested calls must end in a usage-limit error.
    limits = UsageLimits(tool_calls_limit=6)
    expected_message = r'The next tool call would exceed the tool_calls_limit of 6 \(tool_calls=(6)\)'
    with pytest.raises(UsageLimitExceeded, match=expected_message):
        await agent.run('Use tools', usage_limits=limits)

    # Exactly the budgeted number of tool bodies ran.
    assert len(ran) == 6

0 commit comments

Comments
 (0)