Skip to content

Commit 8a11d28

Browse files
Fix race condition in RunUsage.incr() when running tools in parallel
- Add asyncio.Lock to RunUsage class to prevent race conditions
- Make incr() method async and use lock for thread-safe increments
- Update all calls to usage.incr() to use await
- Replace direct tool_calls += 1 with await usage.incr(RunUsage(tool_calls=1))
- Fixes issue where concurrent tool calls could cause undercounted tool_calls
- Maintains backward compatibility with synchronous __add__ method

Resolves #3120
1 parent 6cf43ea commit 8a11d28

File tree

4 files changed

+26
-15
lines changed

4 files changed

+26
-15
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ async def stream(
408408
message_history, model_settings, model_request_parameters, run_context
409409
) as streamed_response:
410410
self._did_stream = True
411-
ctx.state.usage.requests += 1
411+
await ctx.state.usage.incr(_usage.RunUsage(requests=1))
412412
agent_stream = result.AgentStream[DepsT, T](
413413
_raw_stream_response=streamed_response,
414414
_output_schema=ctx.deps.output_schema,
@@ -437,9 +437,9 @@ async def _make_request(
437437

438438
model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
439439
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
440-
ctx.state.usage.requests += 1
440+
await ctx.state.usage.incr(_usage.RunUsage(requests=1))
441441

442-
return self._finish_handling(ctx, model_response)
442+
return await self._finish_handling(ctx, model_response)
443443

444444
async def _prepare_request(
445445
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -475,19 +475,19 @@ async def _prepare_request(
475475
usage = deepcopy(usage)
476476

477477
counted_usage = await ctx.deps.model.count_tokens(message_history, model_settings, model_request_parameters)
478-
usage.incr(counted_usage)
478+
await usage.incr(counted_usage)
479479

480480
ctx.deps.usage_limits.check_before_request(usage)
481481

482482
return model_settings, model_request_parameters, message_history, run_context
483483

484-
def _finish_handling(
484+
async def _finish_handling(
485485
self,
486486
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
487487
response: _messages.ModelResponse,
488488
) -> CallToolsNode[DepsT, NodeRunEndT]:
489489
# Update usage
490-
ctx.state.usage.incr(response.usage)
490+
await ctx.state.usage.incr(response.usage)
491491
if ctx.deps.usage_limits: # pragma: no branch
492492
ctx.deps.usage_limits.check_tokens(ctx.state.usage)
493493

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ async def _call_function_tool(
234234
) as span:
235235
try:
236236
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
237-
usage.tool_calls += 1
237+
await usage.incr(RunUsage(tool_calls=1))
238238

239239
except ToolRetryError as e:
240240
part = e.tool_retry

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations as _annotations
22

3+
import asyncio
34
import dataclasses
45
from copy import copy
56
from dataclasses import dataclass, fields
@@ -189,24 +190,34 @@ class RunUsage(UsageBase):
189190
details: dict[str, int] = dataclasses.field(default_factory=dict)
190191
"""Any extra details returned by the model."""
191192

192-
def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
193+
_lock: asyncio.Lock = dataclasses.field(default_factory=asyncio.Lock)
194+
"""Lock to prevent race conditions when incrementing usage from concurrent tool calls."""
195+
196+
async def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
193197
"""Increment the usage in place.
194198
195199
Args:
196200
incr_usage: The usage to increment by.
197201
"""
198-
if isinstance(incr_usage, RunUsage):
199-
self.requests += incr_usage.requests
200-
self.tool_calls += incr_usage.tool_calls
201-
return _incr_usage_tokens(self, incr_usage)
202+
async with self._lock:
203+
if isinstance(incr_usage, RunUsage):
204+
self.requests += incr_usage.requests
205+
self.tool_calls += incr_usage.tool_calls
206+
return _incr_usage_tokens(self, incr_usage)
202207

203208
def __add__(self, other: RunUsage | RequestUsage) -> RunUsage:
204209
"""Add two RunUsages together.
205210
206211
This is provided so it's trivial to sum usage information from multiple runs.
207212
"""
208213
new_usage = copy(self)
209-
new_usage.incr(other)
214+
# Note: We can't use await here since __add__ must be synchronous
215+
# But __add__ creates a new object, so there's no race condition
216+
# The race condition only happens when modifying the same object concurrently
217+
if isinstance(other, RunUsage):
218+
new_usage.requests += other.requests
219+
new_usage.tool_calls += other.tool_calls
220+
_incr_usage_tokens(new_usage, other)
210221
return new_usage
211222

212223

tests/test_usage_limits.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ async def test_multi_agent_usage_sync():
200200
controller_agent = Agent(TestModel())
201201

202202
@controller_agent.tool
203-
def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
203+
async def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
204204
new_usage = RunUsage(requests=5, input_tokens=2, output_tokens=3)
205-
ctx.usage.incr(new_usage)
205+
await ctx.usage.incr(new_usage)
206206
return 0
207207

208208
result = await controller_agent.run('foobar')

0 commit comments

Comments (0)