Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ async def stream(

model_response = streamed_response.get()

self._finish_handling(ctx, model_response)
await self._finish_handling(ctx, model_response)
assert self._result is not None # this should be set by the previous line

async def _make_request(
Expand All @@ -439,7 +439,7 @@ async def _make_request(
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
ctx.state.usage.requests += 1

return self._finish_handling(ctx, model_response)
return await self._finish_handling(ctx, model_response)

async def _prepare_request(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
Expand Down Expand Up @@ -481,7 +481,7 @@ async def _prepare_request(

return model_settings, model_request_parameters, message_history, run_context

def _finish_handling(
async def _finish_handling(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
response: _messages.ModelResponse,
Expand Down
10 changes: 9 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import asyncio
import json
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field, replace
from functools import cached_property
from typing import Any, Generic

from opentelemetry.trace import Tracer
Expand Down Expand Up @@ -72,6 +74,11 @@ def tool_defs(self) -> list[ToolDefinition]:

return [tool.tool_def for tool in self.tools.values()]

@cached_property
def _usage_lock(self) -> asyncio.Lock:
"""Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions."""
return asyncio.Lock()

def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
"""Whether to require sequential tool calls for a list of tool calls."""
return _sequential_tool_calls_ctx_var.get() or any(
Expand Down Expand Up @@ -234,7 +241,8 @@ async def _call_function_tool(
) as span:
try:
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
usage.tool_calls += 1
async with self._usage_lock:
usage.tool_calls += 1

except ToolRetryError as e:
part = e.tool_retry
Expand Down
36 changes: 36 additions & 0 deletions tests/test_usage_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ModelRequest,
ModelResponse,
RunContext,
TextPart,
ToolCallPart,
ToolReturnPart,
UsageLimitExceeded,
Expand Down Expand Up @@ -355,6 +356,41 @@ def test_deprecated_usage_limits():
assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore


async def test_race_condition_parallel_tool_calls():
    """Test that demonstrates race condition in parallel tool execution.

    The model fans out 10 parallel tool calls on its first request; if the
    read-modify-write updates of ``usage.tool_calls`` from concurrent tasks
    could interleave, the final count would be undercounted. Twenty
    iterations raise the odds of actually hitting such an interleaving.
    """
    # Run multiple iterations to increase chance of catching race condition
    for iteration in range(20):
        model_requests: list[int] = []

        def fan_out_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            model_requests.append(1)
            if len(model_requests) == 1:
                # First request: fan out 10 parallel tool calls for more contention.
                parts = [ToolCallPart('tool_a', {}, f'call_{i}') for i in range(10)]
                return ModelResponse(parts=parts)
            # Any later request: finish the run with a plain text response.
            return ModelResponse(parts=[TextPart(content='done')])

        agent = Agent(FunctionModel(fan_out_model))

        @agent.tool_plain
        async def tool_a() -> str:
            # Two await points so the scheduler can interleave tasks mid-call.
            await asyncio.sleep(0.0001)
            await asyncio.sleep(0.0001)
            return 'result'

        result = await agent.run('test')
        # Without proper synchronization, tool_calls might be undercounted.
        actual = result.usage().tool_calls
        assert actual == 10, f'Iteration {iteration}: Expected 10 tool calls, got {actual}'


async def test_parallel_tool_calls_limit_enforced():
"""Parallel tool calls must not exceed the limit and should raise immediately."""
executed_tools: list[str] = []
Expand Down