Merged
Changes from 1 commit
147 changes: 80 additions & 67 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -408,7 +408,7 @@ async def stream(
message_history, model_settings, model_request_parameters, run_context
) as streamed_response:
self._did_stream = True
ctx.state.usage.requests += 1
# Request count is incremented in _finish_handling via response.usage
Collaborator: No need to include this comment, or the next identical one.

agent_stream = result.AgentStream[DepsT, T](
_raw_stream_response=streamed_response,
_output_schema=ctx.deps.output_schema,
@@ -426,7 +426,7 @@ async def stream(

model_response = streamed_response.get()

self._finish_handling(ctx, model_response)
await self._finish_handling(ctx, model_response)
assert self._result is not None # this should be set by the previous line

async def _make_request(
@@ -437,9 +437,9 @@ async def _make_request(

model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
ctx.state.usage.requests += 1
# Request count is incremented in _finish_handling via response.usage

return self._finish_handling(ctx, model_response)
return await self._finish_handling(ctx, model_response)

async def _prepare_request(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -481,7 +481,7 @@ async def _prepare_request(

return model_settings, model_request_parameters, message_history, run_context

def _finish_handling(
async def _finish_handling(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
response: _messages.ModelResponse,
@@ -895,6 +895,8 @@ async def _call_tools(
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
# Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions
usage_lock = asyncio.Lock()
Collaborator: I think this could be a cached_property on ToolManager, and we wouldn't need to touch agent_graph at all.
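
A minimal sketch of that alternative, assuming ToolManager can carry a functools.cached_property (the class and method names below are illustrative stand-ins, not pydantic-ai's actual API):

import asyncio
from functools import cached_property


class ToolManagerSketch:
    # Hypothetical stand-in for ToolManager, only to show the cached_property idea.

    @cached_property
    def _usage_lock(self) -> asyncio.Lock:
        # Created lazily on first access, then shared by every concurrent
        # tool call routed through this manager instance.
        return asyncio.Lock()

    async def record_tool_call(self, usage: dict[str, int]) -> None:
        async with self._usage_lock:
            usage['tool_calls'] += 1


async def main() -> None:
    manager = ToolManagerSketch()
    usage = {'tool_calls': 0}
    await asyncio.gather(*(manager.record_tool_call(usage) for _ in range(10)))
    assert usage['tool_calls'] == 10


asyncio.run(main())

That would tie the lock's lifetime to the ToolManager instance instead of threading it through _agent_graph via a ContextVar.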


if usage_limits.tool_calls_limit is not None:
projected_usage = deepcopy(usage)
@@ -904,74 +906,85 @@ async def _call_tools(
for call in tool_calls:
yield _messages.FunctionToolCallEvent(call)

with tracer.start_as_current_span(
'running tools',
attributes={
'tools': [call.tool_name for call in tool_calls],
'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
},
):
# Import and set the usage lock context variable for parallel tool execution
from ._tool_manager import _usage_increment_lock_ctx_var # pyright: ignore[reportPrivateUsage]

async def handle_call_or_result(
coro_or_task: Awaitable[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
]
]
| Task[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
]
],
index: int,
) -> _messages.HandleResponseEvent | None:
try:
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
except exceptions.CallDeferred:
deferred_calls_by_index[index] = 'external'
except exceptions.ApprovalRequired:
deferred_calls_by_index[index] = 'unapproved'
else:
tool_parts_by_index[index] = tool_part
if tool_user_content:
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)
token = _usage_increment_lock_ctx_var.set(usage_lock)

return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)

if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
if event := await handle_call_or_result(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
index,
):
yield event
try:
with tracer.start_as_current_span(
'running tools',
attributes={
'tools': [call.tool_name for call in tool_calls],
'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
},
):

else:
tasks = [
asyncio.create_task(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
name=call.tool_name,
)
for call in tool_calls
]

pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
if event := await handle_call_or_result(coro_or_task=task, index=index):
async def handle_call_or_result(
coro_or_task: Awaitable[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart,
str | Sequence[_messages.UserContent] | None,
]
]
| Task[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart,
str | Sequence[_messages.UserContent] | None,
]
],
index: int,
) -> _messages.HandleResponseEvent | None:
try:
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
except exceptions.CallDeferred:
deferred_calls_by_index[index] = 'external'
except exceptions.ApprovalRequired:
deferred_calls_by_index[index] = 'unapproved'
else:
tool_parts_by_index[index] = tool_part
if tool_user_content:
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)

return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)

if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
if event := await handle_call_or_result(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
index,
):
yield event

# We append the results at the end, rather than as they are received, to retain a consistent ordering
# This is mostly just to simplify testing
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
else:
tasks = [
asyncio.create_task(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
name=call.tool_name,
)
for call in tool_calls
]

pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
if event := await handle_call_or_result(coro_or_task=task, index=index):
yield event

for k in sorted(deferred_calls_by_index):
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
# We append the results at the end, rather than as they are received, to retain a consistent ordering
# This is mostly just to simplify testing
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])

for k in sorted(deferred_calls_by_index):
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
finally:
# Reset the context variable
_usage_increment_lock_ctx_var.reset(token)


async def _call_tool(
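
As the comment in the hunk above notes, results are keyed by the original call index and only appended after all tasks finish, so the output order stays deterministic even though asyncio.wait(..., return_when=FIRST_COMPLETED) yields tasks in completion order. A self-contained sketch of that pattern (tool and function names are illustrative, not the library's API):

import asyncio


async def tool(name: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return f'{name} done'


async def run_parallel(calls: list[tuple[str, float]]) -> list[str]:
    tasks = [asyncio.create_task(tool(name, delay), name=name) for name, delay in calls]
    results_by_index: dict[int, str] = {}
    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            # Record by original position, even though completion order differs.
            results_by_index[tasks.index(task)] = task.result()
    # Emit in call order for a deterministic, test-friendly result.
    return [results_by_index[i] for i in sorted(results_by_index)]


async def main() -> None:
    assert await run_parallel([('slow', 0.02), ('fast', 0.0)]) == ['slow done', 'fast done']


asyncio.run(main())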
10 changes: 9 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import json
from collections.abc import Iterator
from contextlib import contextmanager
@@ -21,6 +22,7 @@
from .usage import RunUsage

_sequential_tool_calls_ctx_var: ContextVar[bool] = ContextVar('sequential_tool_calls', default=False)
_usage_increment_lock_ctx_var: ContextVar[asyncio.Lock | None] = ContextVar('usage_increment_lock', default=None)


@dataclass
@@ -234,7 +236,13 @@ async def _call_function_tool(
) as span:
try:
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
usage.tool_calls += 1
# Use lock if available (for parallel tool execution) to prevent race conditions
lock = _usage_increment_lock_ctx_var.get()
if lock is not None:
async with lock:
usage.incr(RunUsage(tool_calls=1))
else:
usage.incr(RunUsage(tool_calls=1))

except ToolRetryError as e:
part = e.tool_retry
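
The shape of the change: the agent graph sets an asyncio.Lock into a ContextVar before spawning parallel tool tasks, and each tool call picks it up when incrementing usage. A stripped-down, runnable sketch of that pattern (variable and function names here are illustrative, not the actual module's):

import asyncio
from contextvars import ContextVar

usage_lock_var: ContextVar[asyncio.Lock | None] = ContextVar('usage_lock', default=None)


class Usage:
    def __init__(self) -> None:
        self.tool_calls = 0


async def call_tool(usage: Usage) -> None:
    lock = usage_lock_var.get()
    if lock is not None:
        # Parallel execution: serialize the read-modify-write on the counter.
        async with lock:
            usage.tool_calls += 1
    else:
        usage.tool_calls += 1


async def run_tools_in_parallel(usage: Usage, n: int) -> None:
    token = usage_lock_var.set(asyncio.Lock())
    try:
        # Tasks created after set() inherit the caller's context,
        # so they all see the same lock through the ContextVar.
        await asyncio.gather(*(call_tool(usage) for _ in range(n)))
    finally:
        usage_lock_var.reset(token)


async def main() -> None:
    usage = Usage()
    await run_tools_in_parallel(usage, 8)
    assert usage.tool_calls == 8


asyncio.run(main())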
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -1178,7 +1178,7 @@ async def _responses_create(
truncation=model_settings.get('openai_truncation', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
previous_response_id=previous_response_id or NOT_GIVEN,
Collaborator: This looks like a broken merge conflict resolution! Please remove it from the diff to make sure we don't accidentally merge this into main.

previous_response_id=previous_response_id,
reasoning=reasoning,
user=model_settings.get('openai_user', NOT_GIVEN),
text=text or NOT_GIVEN,
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/usage.py
@@ -198,12 +198,17 @@ def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
if isinstance(incr_usage, RunUsage):
self.requests += incr_usage.requests
self.tool_calls += incr_usage.tool_calls
else:
# RequestUsage: requests is a property that returns 1
self.requests += incr_usage.requests
Collaborator: Duplicated with if branch.

return _incr_usage_tokens(self, incr_usage)

def __add__(self, other: RunUsage | RequestUsage) -> RunUsage:
"""Add two RunUsages together.

This is provided so it's trivial to sum usage information from multiple runs.

**WARNING:** this CANNOT be used to sum multiple requests without breaking some pricing calculations.
"""
new_usage = copy(self)
new_usage.incr(other)
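
A sketch of the deduplicated shape the reviewer is pointing at: since RequestUsage.requests is a property that returns 1, the requests addition can live outside the isinstance branch. The classes below are simplified stand-ins for RunUsage/RequestUsage, not the real ones:

from dataclasses import dataclass


@dataclass
class RequestUsageSketch:
    input_tokens: int = 0
    output_tokens: int = 0

    @property
    def requests(self) -> int:
        # Each RequestUsage represents exactly one model request.
        return 1


@dataclass
class RunUsageSketch:
    requests: int = 0
    tool_calls: int = 0
    input_tokens: int = 0
    output_tokens: int = 0

    def incr(self, other: 'RunUsageSketch | RequestUsageSketch') -> None:
        # `requests` exists on both types, so no duplicated branch is needed.
        self.requests += other.requests
        if isinstance(other, RunUsageSketch):
            self.tool_calls += other.tool_calls
        self.input_tokens += other.input_tokens
        self.output_tokens += other.output_tokens


usage = RunUsageSketch()
usage.incr(RequestUsageSketch(input_tokens=10, output_tokens=5))
usage.incr(RunUsageSketch(requests=1, tool_calls=2))
assert usage.requests == 2 and usage.tool_calls == 2 and usage.input_tokens == 10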
11 changes: 11 additions & 0 deletions tests/models/test_google.py
@@ -2600,6 +2600,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='8a7952',
identifier='8a7952',
)
)
@@ -2620,6 +2621,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='8a7952',
identifier='8a7952',
)
),
@@ -2644,6 +2646,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='7d173c',
identifier='7d173c',
)
)
@@ -2664,6 +2667,7 @@ async def test_google_image_generation(allow_model_requests: None, google_provid
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='7d173c',
identifier='7d173c',
)
),
@@ -2693,6 +2697,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='9ff9cc',
identifier='9ff9cc',
)
)
@@ -2710,6 +2715,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='2af2a7',
identifier='2af2a7',
)
)
@@ -2730,6 +2736,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='2af2a7',
identifier='2af2a7',
)
),
@@ -2758,6 +2765,7 @@ async def test_google_image_generation_stream(allow_model_requests: None, google
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='2af2a7',
identifier='2af2a7',
)
),
@@ -2796,6 +2804,7 @@ async def test_google_image_generation_with_text(allow_model_requests: None, goo
content=BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='00f2af',
identifier=IsStr(),
)
),
@@ -2831,6 +2840,7 @@ async def test_google_image_or_text_output(allow_model_requests: None, google_pr
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='f82faf',
identifier='f82faf',
)
)
@@ -2849,6 +2859,7 @@ async def test_google_image_and_text_output(allow_model_requests: None, google_p
BinaryImage(
data=IsBytes(),
media_type='image/png',
_identifier='67b12f',
identifier='67b12f',
)
]
2 changes: 1 addition & 1 deletion tests/models/test_openai_responses.py
@@ -3346,7 +3346,7 @@ class Result(BaseModel):
assert response_stream.usage() == snapshot(
RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1)
)
assert run.usage() == snapshot(RunUsage(requests=1))
Collaborator: This looks like a breaking change we shouldn't make.

Collaborator: As mentioned above, this would be a breaking change, so I'd rather ensure at the call site where we call incr that RunUsage.requests == 0.

assert run.usage() == snapshot(RunUsage())
assert run.usage() == snapshot(
RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1)
)
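
A minimal sketch of that suggestion, assuming the call site guards the incoming usage delta rather than changing what run.usage() reports. The types and names below are simplified stand-ins, not pydantic-ai's actual API:

from dataclasses import dataclass


@dataclass
class UsageSketch:
    requests: int = 0
    tool_calls: int = 0

    def incr(self, other: 'UsageSketch') -> None:
        self.requests += other.requests
        self.tool_calls += other.tool_calls


def apply_tool_usage(state_usage: UsageSketch, delta: UsageSketch) -> None:
    # Requests are already counted once per model response elsewhere, so a
    # tool-call delta must arrive with requests == 0 to avoid double counting.
    assert delta.requests == 0, 'request counting happens via response.usage'
    state_usage.incr(delta)


run_usage = UsageSketch(requests=1)  # one model request already recorded
apply_tool_usage(run_usage, UsageSketch(tool_calls=1))
assert run_usage == UsageSketch(requests=1, tool_calls=1)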