Skip to content

Commit 6f8cfc7

Browse files
Address code review feedback
- Remove unnecessary comments about request counting
- Move usage_lock to ToolManager as cached_property for better encapsulation
- Simplify RunUsage.incr() to avoid code duplication
- Clean up _agent_graph.py by removing context var management

This makes the lock management more localized to ToolManager, where parallel execution actually happens, improving code organization and maintainability.
1 parent bea4e20 commit 6f8cfc7

File tree

3 files changed

+72
-90
lines changed

3 files changed

+72
-90
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 64 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,6 @@ async def stream(
408408
message_history, model_settings, model_request_parameters, run_context
409409
) as streamed_response:
410410
self._did_stream = True
411-
# Request count is incremented in _finish_handling via response.usage
412411
agent_stream = result.AgentStream[DepsT, T](
413412
_raw_stream_response=streamed_response,
414413
_output_schema=ctx.deps.output_schema,
@@ -419,8 +418,6 @@ async def stream(
419418
_tool_manager=ctx.deps.tool_manager,
420419
)
421420
yield agent_stream
422-
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
423-
# otherwise usage won't be properly counted:
424421
async for _ in agent_stream:
425422
pass
426423

@@ -437,7 +434,6 @@ async def _make_request(
437434

438435
model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
439436
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
440-
# Request count is incremented in _finish_handling via response.usage
441437

442438
return await self._finish_handling(ctx, model_response)
443439

@@ -895,8 +891,6 @@ async def _call_tools(
895891
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
896892
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
897893
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
898-
# Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions
899-
usage_lock = asyncio.Lock()
900894

901895
if usage_limits.tool_calls_limit is not None:
902896
projected_usage = deepcopy(usage)
@@ -906,85 +900,76 @@ async def _call_tools(
906900
for call in tool_calls:
907901
yield _messages.FunctionToolCallEvent(call)
908902

909-
# Import and set the usage lock context variable for parallel tool execution
910-
from ._tool_manager import _usage_increment_lock_ctx_var # pyright: ignore[reportPrivateUsage]
911-
912-
token = _usage_increment_lock_ctx_var.set(usage_lock)
913-
914-
try:
915-
with tracer.start_as_current_span(
916-
'running tools',
917-
attributes={
918-
'tools': [call.tool_name for call in tool_calls],
919-
'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
920-
},
921-
):
903+
with tracer.start_as_current_span(
904+
'running tools',
905+
attributes={
906+
'tools': [call.tool_name for call in tool_calls],
907+
'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
908+
},
909+
):
922910

923-
async def handle_call_or_result(
924-
coro_or_task: Awaitable[
925-
tuple[
926-
_messages.ToolReturnPart | _messages.RetryPromptPart,
927-
str | Sequence[_messages.UserContent] | None,
928-
]
911+
async def handle_call_or_result(
912+
coro_or_task: Awaitable[
913+
tuple[
914+
_messages.ToolReturnPart | _messages.RetryPromptPart,
915+
str | Sequence[_messages.UserContent] | None,
929916
]
930-
| Task[
931-
tuple[
932-
_messages.ToolReturnPart | _messages.RetryPromptPart,
933-
str | Sequence[_messages.UserContent] | None,
934-
]
935-
],
936-
index: int,
937-
) -> _messages.HandleResponseEvent | None:
938-
try:
939-
tool_part, tool_user_content = (
940-
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
941-
)
942-
except exceptions.CallDeferred:
943-
deferred_calls_by_index[index] = 'external'
944-
except exceptions.ApprovalRequired:
945-
deferred_calls_by_index[index] = 'unapproved'
946-
else:
947-
tool_parts_by_index[index] = tool_part
948-
if tool_user_content:
949-
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)
950-
951-
return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)
952-
953-
if tool_manager.should_call_sequentially(tool_calls):
954-
for index, call in enumerate(tool_calls):
955-
if event := await handle_call_or_result(
956-
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
957-
index,
958-
):
959-
yield event
960-
961-
else:
962-
tasks = [
963-
asyncio.create_task(
964-
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
965-
name=call.tool_name,
966-
)
967-
for call in tool_calls
917+
]
918+
| Task[
919+
tuple[
920+
_messages.ToolReturnPart | _messages.RetryPromptPart,
921+
str | Sequence[_messages.UserContent] | None,
968922
]
923+
],
924+
index: int,
925+
) -> _messages.HandleResponseEvent | None:
926+
try:
927+
tool_part, tool_user_content = (
928+
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
929+
)
930+
except exceptions.CallDeferred:
931+
deferred_calls_by_index[index] = 'external'
932+
except exceptions.ApprovalRequired:
933+
deferred_calls_by_index[index] = 'unapproved'
934+
else:
935+
tool_parts_by_index[index] = tool_part
936+
if tool_user_content:
937+
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)
969938

970-
pending = tasks
971-
while pending:
972-
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
973-
for task in done:
974-
index = tasks.index(task)
975-
if event := await handle_call_or_result(coro_or_task=task, index=index):
976-
yield event
939+
return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)
977940

978-
# We append the results at the end, rather than as they are received, to retain a consistent ordering
979-
# This is mostly just to simplify testing
980-
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
981-
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
941+
if tool_manager.should_call_sequentially(tool_calls):
942+
for index, call in enumerate(tool_calls):
943+
if event := await handle_call_or_result(
944+
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
945+
index,
946+
):
947+
yield event
982948

983-
for k in sorted(deferred_calls_by_index):
984-
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
985-
finally:
986-
# Reset the context variable
987-
_usage_increment_lock_ctx_var.reset(token)
949+
else:
950+
tasks = [
951+
asyncio.create_task(
952+
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
953+
name=call.tool_name,
954+
)
955+
for call in tool_calls
956+
]
957+
958+
pending = tasks
959+
while pending:
960+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
961+
for task in done:
962+
index = tasks.index(task)
963+
if event := await handle_call_or_result(coro_or_task=task, index=index):
964+
yield event
965+
966+
# We append the results at the end, rather than as they are received, to retain a consistent ordering
967+
# This is mostly just to simplify testing
968+
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
969+
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
970+
971+
for k in sorted(deferred_calls_by_index):
972+
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
988973

989974

990975
async def _call_tool(

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from contextlib import contextmanager
77
from contextvars import ContextVar
88
from dataclasses import dataclass, field, replace
9+
from functools import cached_property
910
from typing import Any, Generic
1011

1112
from opentelemetry.trace import Tracer
@@ -22,7 +23,6 @@
2223
from .usage import RunUsage
2324

2425
_sequential_tool_calls_ctx_var: ContextVar[bool] = ContextVar('sequential_tool_calls', default=False)
25-
_usage_increment_lock_ctx_var: ContextVar[asyncio.Lock | None] = ContextVar('usage_increment_lock', default=None)
2626

2727

2828
@dataclass
@@ -74,6 +74,11 @@ def tool_defs(self) -> list[ToolDefinition]:
7474

7575
return [tool.tool_def for tool in self.tools.values()]
7676

77+
@cached_property
78+
def _usage_lock(self) -> asyncio.Lock:
79+
"""Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions."""
80+
return asyncio.Lock()
81+
7782
def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
7883
"""Whether to require sequential tool calls for a list of tool calls."""
7984
return _sequential_tool_calls_ctx_var.get() or any(
@@ -236,12 +241,7 @@ async def _call_function_tool(
236241
) as span:
237242
try:
238243
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
239-
# Use lock if available (for parallel tool execution) to prevent race conditions
240-
lock = _usage_increment_lock_ctx_var.get()
241-
if lock is not None:
242-
async with lock:
243-
usage.incr(RunUsage(tool_calls=1))
244-
else:
244+
async with self._usage_lock:
245245
usage.incr(RunUsage(tool_calls=1))
246246

247247
except ToolRetryError as e:

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,9 @@ def incr(self, incr_usage: RunUsage | RequestUsage) -> None:
195195
Args:
196196
incr_usage: The usage to increment by.
197197
"""
198+
self.requests += incr_usage.requests
198199
if isinstance(incr_usage, RunUsage):
199-
self.requests += incr_usage.requests
200200
self.tool_calls += incr_usage.tool_calls
201-
else:
202-
# RequestUsage: requests is a property that returns 1
203-
self.requests += incr_usage.requests
204201
return _incr_usage_tokens(self, incr_usage)
205202

206203
def __add__(self, other: RunUsage | RequestUsage) -> RunUsage:

0 commit comments

Comments (0)