Skip to content

Commit 8d2b4c9

Browse files
Fix race condition with threading.Lock (exclude lock from comparison)
- Use threading.Lock instead of asyncio.Lock for Python GIL safety
- Exclude lock from dataclass comparison and repr to fix serialization issues
- Keep incr() synchronous for backward compatibility
- Lock prevents race condition in concurrent tool calls
- Compatible with Python 3.10, 3.11, 3.12, and 3.13

Addresses issue #3120 where usage.tool_calls was undercounting when running tools in parallel due to non-atomic increment operations.
1 parent 8a11d28 commit 8d2b4c9

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ async def stream(
408408
message_history, model_settings, model_request_parameters, run_context
409409
) as streamed_response:
410410
self._did_stream = True
411-
await ctx.state.usage.incr(_usage.RunUsage(requests=1))
411+
ctx.state.usage.incr(_usage.RunUsage(requests=1))
412412
agent_stream = result.AgentStream[DepsT, T](
413413
_raw_stream_response=streamed_response,
414414
_output_schema=ctx.deps.output_schema,
@@ -437,9 +437,9 @@ async def _make_request(
437437

438438
model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
439439
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
440-
await ctx.state.usage.incr(_usage.RunUsage(requests=1))
440+
ctx.state.usage.incr(_usage.RunUsage(requests=1))
441441

442-
return await self._finish_handling(ctx, model_response)
442+
return self._finish_handling(ctx, model_response)
443443

444444
async def _prepare_request(
445445
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -475,19 +475,19 @@ async def _prepare_request(
475475
usage = deepcopy(usage)
476476

477477
counted_usage = await ctx.deps.model.count_tokens(message_history, model_settings, model_request_parameters)
478-
await usage.incr(counted_usage)
478+
usage.incr(counted_usage)
479479

480480
ctx.deps.usage_limits.check_before_request(usage)
481481

482482
return model_settings, model_request_parameters, message_history, run_context
483483

484-
async def _finish_handling(
484+
def _finish_handling(
485485
self,
486486
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
487487
response: _messages.ModelResponse,
488488
) -> CallToolsNode[DepsT, NodeRunEndT]:
489489
# Update usage
490-
await ctx.state.usage.incr(response.usage)
490+
ctx.state.usage.incr(response.usage)
491491
if ctx.deps.usage_limits: # pragma: no branch
492492
ctx.deps.usage_limits.check_tokens(ctx.state.usage)
493493

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ async def _call_function_tool(
234234
) as span:
235235
try:
236236
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
237-
await usage.incr(RunUsage(tool_calls=1))
237+
usage.incr(RunUsage(tool_calls=1))
238238

239239
except ToolRetryError as e:
240240
part = e.tool_retry

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations as _annotations
22

3-
import asyncio
3+
import threading
44
import dataclasses
55
from copy import copy
66
from dataclasses import dataclass, fields
@@ -190,16 +190,16 @@ class RunUsage(UsageBase):
190190
details: dict[str, int] = dataclasses.field(default_factory=dict)
191191
"""Any extra details returned by the model."""
192192

193-
_lock: asyncio.Lock = dataclasses.field(default_factory=asyncio.Lock)
193+
_lock: threading.Lock = dataclasses.field(default_factory=threading.Lock, compare=False, repr=False)
194194
"""Lock to prevent race conditions when incrementing usage from concurrent tool calls."""
195195

196-
async def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
196+
def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
197197
"""Increment the usage in place.
198198
199199
Args:
200200
incr_usage: The usage to increment by.
201201
"""
202-
async with self._lock:
202+
with self._lock:
203203
if isinstance(incr_usage, RunUsage):
204204
self.requests += incr_usage.requests
205205
self.tool_calls += incr_usage.tool_calls
@@ -214,6 +214,9 @@ def __add__(self, other: RunUsage | RequestUsage) -> RunUsage:
214214
# Note: We can't use await here since __add__ must be synchronous
215215
# But __add__ creates a new object, so there's no race condition
216216
# The race condition only happens when modifying the same object concurrently
217+
# Create a new lock for the new instance
218+
new_usage._lock = threading.Lock()
219+
217220
if isinstance(other, RunUsage):
218221
new_usage.requests += other.requests
219222
new_usage.tool_calls += other.tool_calls

0 commit comments

Comments (0)