16 changes: 7 additions & 9 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -408,7 +408,6 @@ async def stream(
message_history, model_settings, model_request_parameters, run_context
) as streamed_response:
self._did_stream = True
ctx.state.usage.requests += 1
agent_stream = result.AgentStream[DepsT, T](
_raw_stream_response=streamed_response,
_output_schema=ctx.deps.output_schema,
@@ -419,14 +418,12 @@ async def stream(
_tool_manager=ctx.deps.tool_manager,
)
yield agent_stream
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
# otherwise usage won't be properly counted:
async for _ in agent_stream:
pass

model_response = streamed_response.get()

self._finish_handling(ctx, model_response)
await self._finish_handling(ctx, model_response)
assert self._result is not None # this should be set by the previous line

async def _make_request(
@@ -437,9 +434,8 @@ async def _make_request(

model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
ctx.state.usage.requests += 1

return self._finish_handling(ctx, model_response)
return await self._finish_handling(ctx, model_response)

async def _prepare_request(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -481,7 +477,7 @@ async def _prepare_request(

return model_settings, model_request_parameters, message_history, run_context

def _finish_handling(
async def _finish_handling(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
response: _messages.ModelResponse,
@@ -915,12 +911,14 @@ async def _call_tools(
async def handle_call_or_result(
coro_or_task: Awaitable[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
_messages.ToolReturnPart | _messages.RetryPromptPart,
str | Sequence[_messages.UserContent] | None,
]
]
| Task[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
_messages.ToolReturnPart | _messages.RetryPromptPart,
str | Sequence[_messages.UserContent] | None,
]
],
index: int,
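The comment added to the stream node above ("In case the user didn't manually consume the full stream, ensure it is fully consumed here") is the key to how usage is now recorded: `_finish_handling` only sees a complete `model_response` once the stream has been drained. A minimal standalone sketch of that yield-then-drain pattern, using made-up names rather than the library's actual `AgentStream`/`StreamedResponse` types, could look like this:

import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def open_stream(chunks: list[str]):
    """Yield an async iterator over chunks; totals are only correct once it is exhausted."""
    consumed = 0

    async def iterate():
        nonlocal consumed
        for chunk in chunks:
            consumed += 1
            yield chunk

    stream = iterate()
    try:
        yield stream
    finally:
        # If the caller stopped early, drain the rest so the final count
        # reflects the whole response before any bookkeeping runs.
        async for _ in stream:
            pass
        print(f'consumed {consumed} of {len(chunks)} chunks')


async def main():
    async with open_stream(['a', 'b', 'c']) as stream:
        print('got', await anext(stream))  # caller reads only part of the stream


asyncio.run(main())

Here the drain happens in the context manager rather than in the node itself, but the effect is the same: whatever bookkeeping follows the yield sees the fully consumed response.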
10 changes: 9 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -1,10 +1,12 @@
from __future__ import annotations

import asyncio
import json
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field, replace
from functools import cached_property
from typing import Any, Generic

from opentelemetry.trace import Tracer
@@ -72,6 +74,11 @@ def tool_defs(self) -> list[ToolDefinition]:

return [tool.tool_def for tool in self.tools.values()]

@cached_property
def _usage_lock(self) -> asyncio.Lock:
"""Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions."""
return asyncio.Lock()

def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
"""Whether to require sequential tool calls for a list of tool calls."""
return _sequential_tool_calls_ctx_var.get() or any(
@@ -234,7 +241,8 @@ async def _call_function_tool(
) as span:
try:
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
usage.tool_calls += 1
async with self._usage_lock:
usage.incr(RunUsage(tool_calls=1))

except ToolRetryError as e:
part = e.tool_retry
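For readers wondering what the new `_usage_lock` buys: when tool calls run as concurrent asyncio tasks, a read-modify-write of a shared counter that spans an await point can interleave between tasks, and increments get lost. Below is a small self-contained sketch of that failure mode next to the guarded version; the `Usage` class here is a toy stand-in, not the library's `RunUsage`:

import asyncio
from dataclasses import dataclass, field


@dataclass
class Usage:
    """Toy shared accumulator (not the library's RunUsage)."""

    tool_calls: int = 0
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)

    async def incr_unsafe(self) -> None:
        # Read, yield control, then write: another task can run in between,
        # and both tasks end up writing back the same stale value.
        current = self.tool_calls
        await asyncio.sleep(0)
        self.tool_calls = current + 1

    async def incr_safe(self) -> None:
        # Serialize the whole read-modify-write so tasks cannot interleave.
        async with self._lock:
            current = self.tool_calls
            await asyncio.sleep(0)
            self.tool_calls = current + 1


async def main() -> None:
    for name in ('incr_unsafe', 'incr_safe'):
        usage = Usage()
        await asyncio.gather(*(getattr(usage, name)() for _ in range(10)))
        print(name, usage.tool_calls)


asyncio.run(main())

Running this typically prints a total well below 10 for the unsafe variant and exactly 10 for the locked one, which mirrors the undercounting the new test further down tries to provoke.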
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -1178,7 +1178,7 @@ async def _responses_create(
truncation=model_settings.get('openai_truncation', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
previous_response_id=previous_response_id or NOT_GIVEN,
Collaborator:
This looks like a broken merge conflict resolution! Please remove it from the diff to make sure we don't accidentally merge this into main.

previous_response_id=previous_response_id,
reasoning=reasoning,
user=model_settings.get('openai_user', NOT_GIVEN),
text=text or NOT_GIVEN,
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/usage.py
@@ -195,15 +195,17 @@ def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
Args:
incr_usage: The usage to increment by.
"""
self.requests += incr_usage.requests
if isinstance(incr_usage, RunUsage):
self.requests += incr_usage.requests
self.tool_calls += incr_usage.tool_calls
return _incr_usage_tokens(self, incr_usage)

def __add__(self, other: RunUsage | RequestUsage) -> RunUsage:
"""Add two RunUsages together.

This is provided so it's trivial to sum usage information from multiple runs.

**WARNING:** this CANNOT be used to sum multiple requests without breaking some pricing calculations.
"""
new_usage = copy(self)
new_usage.incr(other)
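Taken together with the new `test_run_usage_with_request_usage` test further down, the intended dispatch appears to be: a `RunUsage` contributes its own `requests` and `tool_calls` counts, while a `RequestUsage` counts as exactly one request (its `requests` property returns 1) and carries no tool calls. A simplified, hedged reconstruction with toy dataclasses, not the real `usage.py` definitions, might look like:

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class RequestUsage:
    """Toy stand-in for per-request usage (not the real class)."""

    input_tokens: int = 0
    output_tokens: int = 0
    details: dict[str, int] = field(default_factory=dict)

    @property
    def requests(self) -> int:
        # A single RequestUsage always represents exactly one model request.
        return 1


@dataclass
class RunUsage:
    """Toy stand-in for accumulated usage across a whole run (not the real class)."""

    requests: int = 0
    tool_calls: int = 0
    input_tokens: int = 0
    output_tokens: int = 0
    details: dict[str, int] = field(default_factory=dict)

    def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
        if isinstance(incr_usage, RunUsage):
            # Another run's totals: carry over its request and tool-call counts as-is.
            self.requests += incr_usage.requests
            self.tool_calls += incr_usage.tool_calls
        else:
            # A single request: counts as one request and has no tool calls of its own.
            self.requests += incr_usage.requests
        self.input_tokens += incr_usage.input_tokens
        self.output_tokens += incr_usage.output_tokens
        for key, value in incr_usage.details.items():
            self.details[key] = self.details.get(key, 0) + value


run = RunUsage(requests=1, input_tokens=10, output_tokens=20, tool_calls=1)
run.incr(RequestUsage(input_tokens=5, output_tokens=10))
print(run.requests, run.tool_calls, run.input_tokens, run.output_tokens)  # 2 1 15 30

With these toys the final counts match the assertions in the new test; the real module may split the token and detail handling into a helper such as `_incr_usage_tokens`.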
11 changes: 11 additions & 0 deletions tests/models/test_google.py
@@ -2600,6 +2600,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='8a7952',
identifier='8a7952',
)
)
@@ -2620,6 +2621,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='8a7952',
identifier='8a7952',
)
),
@@ -2644,6 +2646,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='7d173c',
identifier='7d173c',
)
)
@@ -2664,6 +2667,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='7d173c',
identifier='7d173c',
)
),
@@ -2693,6 +2697,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='9ff9cc',
identifier='9ff9cc',
)
)
@@ -2710,6 +2715,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='2af2a7',
identifier='2af2a7',
)
)
@@ -2730,6 +2736,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='2af2a7',
identifier='2af2a7',
)
),
@@ -2758,6 +2765,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='2af2a7',
identifier='2af2a7',
)
),
@@ -2796,6 +2804,7 @@ async def test_google_image_generation_with_text(allow_model_requests: None, goo
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='00f2af',
identifier=IsStr(),
)
),
@@ -2831,6 +2840,7 @@ async def test_google_image_or_text_output(allow_model_requests: None, google_pr
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='f82faf',
identifier='f82faf',
)
)
@@ -2849,6 +2859,7 @@ async def test_google_image_and_text_output(allow_model_requests: None, google_p
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='67b12f',
identifier='67b12f',
)
]
2 changes: 1 addition & 1 deletion tests/models/test_openai_responses.py
@@ -3346,7 +3346,7 @@ class Result(BaseModel):
assert response_stream.usage() == snapshot(
RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1)
)
assert run.usage() == snapshot(RunUsage(requests=1))
Collaborator:
This looks like a breaking change we shouldn't make

Collaborator:
As mentioned above, this would be a breaking change, so I'd rather ensure at the call site where we call incr that RunUsage.requests == 0.

assert run.usage() == snapshot(RunUsage())
assert run.usage() == snapshot(
RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1)
)
73 changes: 72 additions & 1 deletion tests/test_usage_limits.py
@@ -17,6 +17,7 @@
ModelRequest,
ModelResponse,
RunContext,
TextPart,
ToolCallPart,
ToolReturnPart,
UsageLimitExceeded,
@@ -200,7 +201,7 @@ async def test_multi_agent_usage_sync():
controller_agent = Agent(TestModel())

@controller_agent.tool
def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
async def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
Collaborator:
Is this needed?

new_usage = RunUsage(requests=5, input_tokens=2, output_tokens=3)
ctx.usage.incr(new_usage)
return 0
@@ -248,6 +249,41 @@ def test_add_usages():
assert RunUsage() + RunUsage() == RunUsage()


def test_run_usage_with_request_usage():
"""Test RunUsage operations with RequestUsage to ensure coverage of RequestUsage branches."""
run_usage = RunUsage(requests=1, input_tokens=10, output_tokens=20, tool_calls=1)
request_usage = RequestUsage(input_tokens=5, output_tokens=10)

# Test __add__ with RequestUsage
result = run_usage + request_usage
assert result.requests == 2 # 1 + 1 (RequestUsage.requests property returns 1)
assert result.input_tokens == 15
assert result.output_tokens == 30
assert result.tool_calls == 1 # RequestUsage doesn't have tool_calls

# Test incr with RequestUsage (covers elif isinstance(incr_usage, RequestUsage) branch)
run_usage2 = RunUsage(requests=2, input_tokens=20, output_tokens=30, tool_calls=2)
run_usage2.incr(request_usage)
assert run_usage2.requests == 3 # 2 + 1
assert run_usage2.input_tokens == 25 # 20 + 5
assert run_usage2.output_tokens == 40 # 30 + 10
assert run_usage2.tool_calls == 2 # Unchanged

# Test incr with empty details dict (covers empty for loop branch in _incr_usage_tokens)
run_usage3 = RunUsage(requests=0, tool_calls=0)
request_usage_no_details = RequestUsage(input_tokens=5, output_tokens=10)
assert request_usage_no_details.details == {} # Ensure details is empty
run_usage3.incr(request_usage_no_details)
assert run_usage3.requests == 1
assert run_usage3.details == {}

# Test incr with non-empty details dict
run_usage4 = RunUsage(requests=0, tool_calls=0, details={'reasoning_tokens': 10})
request_usage_with_details = RequestUsage(input_tokens=5, output_tokens=10, details={'reasoning_tokens': 5})
run_usage4.incr(request_usage_with_details)
assert run_usage4.details == {'reasoning_tokens': 15}


async def test_tool_call_limit() -> None:
test_agent = Agent(TestModel())

@@ -355,6 +391,41 @@ def test_deprecated_usage_limits():
assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore


async def test_race_condition_parallel_tool_calls():
"""Test that demonstrates race condition in parallel tool execution.

This test would fail intermittently on main without the fix because multiple
asyncio tasks calling usage.incr() can interleave their read-modify-write operations.
"""
# Run multiple iterations to increase chance of catching race condition
for iteration in range(20):
call_count = 0

def parallel_tools_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
nonlocal call_count
call_count += 1
if call_count == 1:
# Return 10 parallel tool calls for more contention
return ModelResponse(parts=[ToolCallPart('tool_a', {}, f'call_{i}') for i in range(10)])
else:
# Return final text response
return ModelResponse(parts=[TextPart(content='done')])

agent = Agent(FunctionModel(parallel_tools_model))

@agent.tool_plain
async def tool_a() -> str:
# Add multiple await points to increase chance of task interleaving
await asyncio.sleep(0.0001)
await asyncio.sleep(0.0001)
return 'result'

result = await agent.run('test')
# Without proper synchronization, tool_calls might be undercounted
actual = result.usage().tool_calls
assert actual == 10, f'Iteration {iteration}: Expected 10 tool calls, got {actual}'


async def test_parallel_tool_calls_limit_enforced():
"""Parallel tool calls must not exceed the limit and should raise immediately."""
executed_tools: list[str] = []