Skip to content

Commit 1e0e99c

Browse files
authored (author name lost in extraction)
Revert "Fix RunUsage.tool_calls being undercounted due to race condition when running tools in parallel" (#3174)
1 parent: afccc1b · commit: 1e0e99c

File tree

3 files changed

+4
-48
lines changed

3 files changed

+4
-48
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ async def stream(
426426

427427
model_response = streamed_response.get()
428428

429-
await self._finish_handling(ctx, model_response)
429+
self._finish_handling(ctx, model_response)
430430
assert self._result is not None # this should be set by the previous line
431431

432432
async def _make_request(
@@ -439,7 +439,7 @@ async def _make_request(
439439
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
440440
ctx.state.usage.requests += 1
441441

442-
return await self._finish_handling(ctx, model_response)
442+
return self._finish_handling(ctx, model_response)
443443

444444
async def _prepare_request(
445445
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -481,7 +481,7 @@ async def _prepare_request(
481481

482482
return model_settings, model_request_parameters, message_history, run_context
483483

484-
async def _finish_handling(
484+
def _finish_handling(
485485
self,
486486
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
487487
response: _messages.ModelResponse,

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import json
54
from collections.abc import Iterator
65
from contextlib import contextmanager
76
from contextvars import ContextVar
87
from dataclasses import dataclass, field, replace
9-
from functools import cached_property
108
from typing import Any, Generic
119

1210
from opentelemetry.trace import Tracer
@@ -74,11 +72,6 @@ def tool_defs(self) -> list[ToolDefinition]:
7472

7573
return [tool.tool_def for tool in self.tools.values()]
7674

77-
@cached_property
78-
def _usage_lock(self) -> asyncio.Lock:
79-
"""Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions."""
80-
return asyncio.Lock()
81-
8275
def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
8376
"""Whether to require sequential tool calls for a list of tool calls."""
8477
return _sequential_tool_calls_ctx_var.get() or any(
@@ -241,8 +234,7 @@ async def _call_function_tool(
241234
) as span:
242235
try:
243236
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
244-
async with self._usage_lock:
245-
usage.tool_calls += 1
237+
usage.tool_calls += 1
246238

247239
except ToolRetryError as e:
248240
part = e.tool_retry

tests/test_usage_limits.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
ModelRequest,
1818
ModelResponse,
1919
RunContext,
20-
TextPart,
2120
ToolCallPart,
2221
ToolReturnPart,
2322
UsageLimitExceeded,
@@ -356,41 +355,6 @@ def test_deprecated_usage_limits():
356355
assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore
357356

358357

359-
async def test_race_condition_parallel_tool_calls():
360-
"""Test that demonstrates race condition in parallel tool execution.
361-
362-
This test would fail intermittently on main without the fix because multiple
363-
asyncio tasks calling usage.incr() can interleave their read-modify-write operations.
364-
"""
365-
# Run multiple iterations to increase chance of catching race condition
366-
for iteration in range(20):
367-
call_count = 0
368-
369-
def parallel_tools_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
370-
nonlocal call_count
371-
call_count += 1
372-
if call_count == 1:
373-
# Return 10 parallel tool calls for more contention
374-
return ModelResponse(parts=[ToolCallPart('tool_a', {}, f'call_{i}') for i in range(10)])
375-
else:
376-
# Return final text response
377-
return ModelResponse(parts=[TextPart(content='done')])
378-
379-
agent = Agent(FunctionModel(parallel_tools_model))
380-
381-
@agent.tool_plain
382-
async def tool_a() -> str:
383-
# Add multiple await points to increase chance of task interleaving
384-
await asyncio.sleep(0.0001)
385-
await asyncio.sleep(0.0001)
386-
return 'result'
387-
388-
result = await agent.run('test')
389-
# Without proper synchronization, tool_calls might be undercounted
390-
actual = result.usage().tool_calls
391-
assert actual == 10, f'Iteration {iteration}: Expected 10 tool calls, got {actual}'
392-
393-
394358
async def test_parallel_tool_calls_limit_enforced():
395359
"""Parallel tool calls must not exceed the limit and should raise immediately."""
396360
executed_tools: list[str] = []

0 commit comments

Comments (0)