Skip to content

Commit afccc1b

Browse files
Fix RunUsage.tool_calls being undercounted due to race condition when running tools in parallel (#3133)
1 parent 2af6faf commit afccc1b

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ async def stream(
426426

427427
model_response = streamed_response.get()
428428

429-
self._finish_handling(ctx, model_response)
429+
await self._finish_handling(ctx, model_response)
430430
assert self._result is not None # this should be set by the previous line
431431

432432
async def _make_request(
@@ -439,7 +439,7 @@ async def _make_request(
439439
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
440440
ctx.state.usage.requests += 1
441441

442-
return self._finish_handling(ctx, model_response)
442+
return await self._finish_handling(ctx, model_response)
443443

444444
async def _prepare_request(
445445
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -481,7 +481,7 @@ async def _prepare_request(
481481

482482
return model_settings, model_request_parameters, message_history, run_context
483483

484-
def _finish_handling(
484+
async def _finish_handling(
485485
self,
486486
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
487487
response: _messages.ModelResponse,

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
from collections.abc import Iterator
56
from contextlib import contextmanager
67
from contextvars import ContextVar
78
from dataclasses import dataclass, field, replace
9+
from functools import cached_property
810
from typing import Any, Generic
911

1012
from opentelemetry.trace import Tracer
@@ -72,6 +74,11 @@ def tool_defs(self) -> list[ToolDefinition]:
7274

7375
return [tool.tool_def for tool in self.tools.values()]
7476

77+
@cached_property
78+
def _usage_lock(self) -> asyncio.Lock:
79+
"""Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions."""
80+
return asyncio.Lock()
81+
7582
def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
7683
"""Whether to require sequential tool calls for a list of tool calls."""
7784
return _sequential_tool_calls_ctx_var.get() or any(
@@ -234,7 +241,8 @@ async def _call_function_tool(
234241
) as span:
235242
try:
236243
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
237-
usage.tool_calls += 1
244+
async with self._usage_lock:
245+
usage.tool_calls += 1
238246

239247
except ToolRetryError as e:
240248
part = e.tool_retry

tests/test_usage_limits.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ModelRequest,
1818
ModelResponse,
1919
RunContext,
20+
TextPart,
2021
ToolCallPart,
2122
ToolReturnPart,
2223
UsageLimitExceeded,
@@ -355,6 +356,41 @@ def test_deprecated_usage_limits():
355356
assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore
356357

357358

359+
async def test_race_condition_parallel_tool_calls():
360+
"""Test that demonstrates race condition in parallel tool execution.
361+
362+
This test would fail intermittently on main without the fix because multiple
363+
    asyncio tasks performing `usage.tool_calls += 1` can interleave their read-modify-write operations.
364+
"""
365+
# Run multiple iterations to increase chance of catching race condition
366+
for iteration in range(20):
367+
call_count = 0
368+
369+
def parallel_tools_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
370+
nonlocal call_count
371+
call_count += 1
372+
if call_count == 1:
373+
# Return 10 parallel tool calls for more contention
374+
return ModelResponse(parts=[ToolCallPart('tool_a', {}, f'call_{i}') for i in range(10)])
375+
else:
376+
# Return final text response
377+
return ModelResponse(parts=[TextPart(content='done')])
378+
379+
agent = Agent(FunctionModel(parallel_tools_model))
380+
381+
@agent.tool_plain
382+
async def tool_a() -> str:
383+
# Add multiple await points to increase chance of task interleaving
384+
await asyncio.sleep(0.0001)
385+
await asyncio.sleep(0.0001)
386+
return 'result'
387+
388+
result = await agent.run('test')
389+
# Without proper synchronization, tool_calls might be undercounted
390+
actual = result.usage().tool_calls
391+
assert actual == 10, f'Iteration {iteration}: Expected 10 tool calls, got {actual}'
392+
393+
358394
async def test_parallel_tool_calls_limit_enforced():
359395
"""Parallel tool calls must not exceed the limit and should raise immediately."""
360396
executed_tools: list[str] = []

0 commit comments

Comments
 (0)