|
| 1 | +import asyncio |
1 | 2 | import functools |
2 | 3 | import operator |
3 | 4 | import re |
|
12 | 13 |
|
13 | 14 | from pydantic_ai import Agent, RunContext, UsageLimitExceeded |
14 | 15 | from pydantic_ai.exceptions import ModelRetry |
15 | | -from pydantic_ai.messages import ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart, UserPromptPart |
| 16 | +from pydantic_ai.messages import ( |
| 17 | + ModelMessage, |
| 18 | + ModelRequest, |
| 19 | + ModelResponse, |
| 20 | + ToolCallPart, |
| 21 | + ToolReturnPart, |
| 22 | + UserPromptPart, |
| 23 | +) |
| 24 | +from pydantic_ai.models.function import AgentInfo, FunctionModel |
16 | 25 | from pydantic_ai.models.test import TestModel |
17 | 26 | from pydantic_ai.output import ToolOutput |
18 | 27 | from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits |
@@ -308,3 +317,67 @@ def test_deprecated_usage_limits(): |
308 | 317 | snapshot(['DeprecationWarning: `response_tokens_limit` is deprecated, use `output_tokens_limit` instead']) |
309 | 318 | ): |
310 | 319 | assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore |
| 320 | + |
| 321 | + |
| 322 | +async def test_parallel_tool_calls_limit_enforced(): |
| 323 | + """Parallel tool calls must not exceed the limit and should raise immediately.""" |
| 324 | + executed_tools: list[str] = [] |
| 325 | + |
| 326 | + model_call_count = 0 |
| 327 | + |
| 328 | + def test_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: |
| 329 | + nonlocal model_call_count |
| 330 | + model_call_count += 1 |
| 331 | + |
| 332 | + if model_call_count == 1: |
| 333 | + # First response: 5 parallel tool calls |
| 334 | + return ModelResponse( |
| 335 | + parts=[ |
| 336 | + ToolCallPart('tool_a', {}, 'call_1'), |
| 337 | + ToolCallPart('tool_b', {}, 'call_2'), |
| 338 | + ToolCallPart('tool_c', {}, 'call_3'), |
| 339 | + ToolCallPart('tool_a', {}, 'call_4'), |
| 340 | + ToolCallPart('tool_b', {}, 'call_5'), |
| 341 | + ] |
| 342 | + ) |
| 343 | + else: |
| 344 | + assert model_call_count == 2 |
| 345 | + # Second response: 3 parallel tool calls (should exceed limit) |
| 346 | + return ModelResponse( |
| 347 | + parts=[ |
| 348 | + ToolCallPart('tool_c', {}, 'call_6'), |
| 349 | + ToolCallPart('tool_a', {}, 'call_7'), |
| 350 | + ToolCallPart('tool_b', {}, 'call_8'), |
| 351 | + ] |
| 352 | + ) |
| 353 | + |
| 354 | + test_model = FunctionModel(test_model_function) |
| 355 | + agent = Agent(test_model) |
| 356 | + |
| 357 | + @agent.tool_plain |
| 358 | + async def tool_a() -> str: |
| 359 | + await asyncio.sleep(0.01) |
| 360 | + executed_tools.append('a') |
| 361 | + return 'result a' |
| 362 | + |
| 363 | + @agent.tool_plain |
| 364 | + async def tool_b() -> str: |
| 365 | + await asyncio.sleep(0.01) |
| 366 | + executed_tools.append('b') |
| 367 | + return 'result b' |
| 368 | + |
| 369 | + @agent.tool_plain |
| 370 | + async def tool_c() -> str: |
| 371 | + await asyncio.sleep(0.01) |
| 372 | + executed_tools.append('c') |
| 373 | + return 'result c' |
| 374 | + |
| 375 | + # Run with tool call limit of 6; expecting an error once the limit is reached |
| 376 | + with pytest.raises( |
| 377 | + UsageLimitExceeded, |
| 378 | + match=r'The next tool call would exceed the tool_calls_limit of 6 \(tool_calls=(6)\)', |
| 379 | + ): |
| 380 | + await agent.run('Use tools', usage_limits=UsageLimits(tool_calls_limit=6)) |
| 381 | + |
| 382 | + # Only 6 tool calls should have actually executed |
| 383 | + assert len(executed_tools) == 6 |
0 commit comments