Skip to content

Commit cb5de05

Browse files
Convert to asyncio.Lock for async compatibility
- Replace threading.Lock with asyncio.Lock to fix async context issues - Make incr() method async and use async with lock - Update all usage.incr() calls to use await - Fix tool manager to use await usage.incr(RunUsage(tool_calls=1)) - Update agent graph _finish_handling to be async - Fix streaming tests that were failing with TypeError This resolves the 16 TypeError failures in streaming tests by ensuring thread-safe operations work correctly in async contexts.
1 parent 306626c commit cb5de05

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ async def stream(
426426

427427
model_response = streamed_response.get()
428428

429-
self._finish_handling(ctx, model_response)
429+
await self._finish_handling(ctx, model_response)
430430
assert self._result is not None # this should be set by the previous line
431431

432432
async def _make_request(
@@ -437,9 +437,9 @@ async def _make_request(
437437

438438
model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
439439
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
440-
ctx.state.usage.requests += 1
440+
await ctx.state.usage.incr(_usage.RunUsage(requests=1))
441441

442-
return self._finish_handling(ctx, model_response)
442+
return await self._finish_handling(ctx, model_response)
443443

444444
async def _prepare_request(
445445
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -475,19 +475,19 @@ async def _prepare_request(
475475
usage = deepcopy(usage)
476476

477477
counted_usage = await ctx.deps.model.count_tokens(message_history, model_settings, model_request_parameters)
478-
usage.incr(counted_usage)
478+
await usage.incr(counted_usage)
479479

480480
ctx.deps.usage_limits.check_before_request(usage)
481481

482482
return model_settings, model_request_parameters, message_history, run_context
483483

484-
def _finish_handling(
484+
async def _finish_handling(
485485
self,
486486
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
487487
response: _messages.ModelResponse,
488488
) -> CallToolsNode[DepsT, NodeRunEndT]:
489489
# Update usage
490-
ctx.state.usage.incr(response.usage)
490+
await ctx.state.usage.incr(response.usage)
491491
if ctx.deps.usage_limits: # pragma: no branch
492492
ctx.deps.usage_limits.check_tokens(ctx.state.usage)
493493

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ async def _call_function_tool(
234234
) as span:
235235
try:
236236
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
237-
usage.tool_calls += 1
237+
await usage.incr(RunUsage(tool_calls=1))
238238

239239
except ToolRetryError as e:
240240
part = e.tool_retry

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

3+
import asyncio
34
import dataclasses
4-
import threading
55
from copy import copy
66
from dataclasses import dataclass, fields
77
from typing import Annotated, Any
@@ -193,19 +193,19 @@ class RunUsage(UsageBase):
193193
"""Lock to prevent race conditions when incrementing usage from concurrent tool calls."""
194194

195195
@property
196-
def _lock(self) -> threading.Lock:
196+
def _lock(self) -> asyncio.Lock:
197197
"""Get the lock, creating it if it doesn't exist."""
198198
if not hasattr(self, '__lock'):
199-
self.__lock = threading.Lock()
199+
self.__lock = asyncio.Lock()
200200
return self.__lock
201201

202-
def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
202+
async def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
203203
"""Increment the usage in place.
204204
205205
Args:
206206
incr_usage: The usage to increment by.
207207
"""
208-
with self._lock:
208+
async with self._lock:
209209
if isinstance(incr_usage, RunUsage):
210210
self.requests += incr_usage.requests
211211
self.tool_calls += incr_usage.tool_calls

tests/test_usage_limits.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ async def test_multi_agent_usage_sync():
200200
controller_agent = Agent(TestModel())
201201

202202
@controller_agent.tool
203-
def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
203+
async def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
204204
new_usage = RunUsage(requests=5, input_tokens=2, output_tokens=3)
205-
ctx.usage.incr(new_usage)
205+
await ctx.usage.incr(new_usage)
206206
return 0
207207

208208
result = await controller_agent.run('foobar')

0 commit comments

Comments
 (0)