diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index d7a54c5c71..383b214dca 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -426,7 +426,7 @@ async def stream( model_response = streamed_response.get() - self._finish_handling(ctx, model_response) + await self._finish_handling(ctx, model_response) assert self._result is not None # this should be set by the previous line async def _make_request( @@ -439,7 +439,7 @@ async def _make_request( model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.requests += 1 - return self._finish_handling(ctx, model_response) + return await self._finish_handling(ctx, model_response) async def _prepare_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] @@ -481,7 +481,7 @@ async def _prepare_request( return model_settings, model_request_parameters, message_history, run_context - def _finish_handling( + async def _finish_handling( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], response: _messages.ModelResponse, diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index a5546a4e01..167092657d 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -1,10 +1,12 @@ from __future__ import annotations +import asyncio import json from collections.abc import Iterator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field, replace +from functools import cached_property from typing import Any, Generic from opentelemetry.trace import Tracer @@ -72,6 +74,11 @@ def tool_defs(self) -> list[ToolDefinition]: return [tool.tool_def for tool in self.tools.values()] + @cached_property + def _usage_lock(self) -> asyncio.Lock: + """Lock 
to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions.""" + return asyncio.Lock() + def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool: """Whether to require sequential tool calls for a list of tool calls.""" return _sequential_tool_calls_ctx_var.get() or any( @@ -234,7 +241,8 @@ async def _call_function_tool( ) as span: try: tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors) - usage.tool_calls += 1 + async with self._usage_lock: + usage.tool_calls += 1 except ToolRetryError as e: part = e.tool_retry diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index d3c40ed375..59bc049dc6 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -17,6 +17,7 @@ ModelRequest, ModelResponse, RunContext, + TextPart, ToolCallPart, ToolReturnPart, UsageLimitExceeded, @@ -355,6 +356,41 @@ def test_deprecated_usage_limits(): assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore + +async def test_race_condition_parallel_tool_calls(): + """Test that demonstrates race condition in parallel tool execution. + + This test would fail intermittently on main without the fix because multiple + asyncio tasks incrementing usage.tool_calls can interleave their read-modify-write operations. 
+ """ + # Run multiple iterations to increase chance of catching race condition + for iteration in range(20): + call_count = 0 + + def parallel_tools_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + # Return 10 parallel tool calls for more contention + return ModelResponse(parts=[ToolCallPart('tool_a', {}, f'call_{i}') for i in range(10)]) + else: + # Return final text response + return ModelResponse(parts=[TextPart(content='done')]) + + agent = Agent(FunctionModel(parallel_tools_model)) + + @agent.tool_plain + async def tool_a() -> str: + # Add multiple await points to increase chance of task interleaving + await asyncio.sleep(0.0001) + await asyncio.sleep(0.0001) + return 'result' + + result = await agent.run('test') + # Without proper synchronization, tool_calls might be undercounted + actual = result.usage().tool_calls + assert actual == 10, f'Iteration {iteration}: Expected 10 tool calls, got {actual}' + + async def test_parallel_tool_calls_limit_enforced(): """Parallel tool calls must not exceed the limit and should raise immediately.""" executed_tools: list[str] = []