diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 383b214dca..d7a54c5c71 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -426,7 +426,7 @@ async def stream( model_response = streamed_response.get() - await self._finish_handling(ctx, model_response) + self._finish_handling(ctx, model_response) assert self._result is not None # this should be set by the previous line async def _make_request( @@ -439,7 +439,7 @@ async def _make_request( model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.requests += 1 - return await self._finish_handling(ctx, model_response) + return self._finish_handling(ctx, model_response) async def _prepare_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] @@ -481,7 +481,7 @@ async def _prepare_request( return model_settings, model_request_parameters, message_history, run_context - async def _finish_handling( + def _finish_handling( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], response: _messages.ModelResponse, diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index 167092657d..a5546a4e01 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -1,12 +1,10 @@ from __future__ import annotations -import asyncio import json from collections.abc import Iterator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field, replace -from functools import cached_property from typing import Any, Generic from opentelemetry.trace import Tracer @@ -74,11 +72,6 @@ def tool_defs(self) -> list[ToolDefinition]: return [tool.tool_def for tool in self.tools.values()] - @cached_property - def _usage_lock(self) -> asyncio.Lock: - """Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions.""" - return asyncio.Lock() - def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool: """Whether to require sequential tool calls for a list of tool calls.""" return _sequential_tool_calls_ctx_var.get() or any( @@ -241,8 +234,7 @@ async def _call_function_tool( ) as span: try: tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors) - async with self._usage_lock: - usage.tool_calls += 1 + usage.tool_calls += 1 except ToolRetryError as e: part = e.tool_retry diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 59bc049dc6..d3c40ed375 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -17,7 +17,6 @@ ModelRequest, ModelResponse, RunContext, - TextPart, ToolCallPart, ToolReturnPart, UsageLimitExceeded, @@ -356,41 +355,6 @@ def test_deprecated_usage_limits(): assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore -async def test_race_condition_parallel_tool_calls(): - """Test that demonstrates race condition in parallel tool execution. - - This test would fail intermittently on main without the fix because multiple - asyncio tasks calling usage.incr() can interleave their read-modify-write operations. - """ - # Run multiple iterations to increase chance of catching race condition - for iteration in range(20): - call_count = 0 - - def parallel_tools_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: - nonlocal call_count - call_count += 1 - if call_count == 1: - # Return 10 parallel tool calls for more contention - return ModelResponse(parts=[ToolCallPart('tool_a', {}, f'call_{i}') for i in range(10)]) - else: - # Return final text response - return ModelResponse(parts=[TextPart(content='done')]) - - agent = Agent(FunctionModel(parallel_tools_model)) - - @agent.tool_plain - async def tool_a() -> str: - # Add multiple await points to increase chance of task interleaving - await asyncio.sleep(0.0001) - await asyncio.sleep(0.0001) - return 'result' - - result = await agent.run('test') - # Without proper synchronization, tool_calls might be undercounted - actual = result.usage().tool_calls - assert actual == 10, f'Iteration {iteration}: Expected 10 tool calls, got {actual}' - - async def test_parallel_tool_calls_limit_enforced(): """Parallel tool calls must not exceed the limit and should raise immediately.""" executed_tools: list[str] = []