Commit 68918cb
refactor: only successful tool invocations are counted towards tool_calls usage metric
- Adjusted the tool call counting mechanism to ensure that only successful tool invocations are counted towards the `tool_calls` metric.
- Updated documentation in `tools.md` to clarify that output tools do not increment the `tool_calls` count.
- Modified multiple test cases to reflect the correct counting of tool calls, including tests for failed tool calls and their impact on usage metrics.
Parent: 146ad10 · Commit: 68918cb
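
In practice, the effect of the change can be sketched like this (illustrative only, not part of the commit; it assumes `Agent`, `TestModel`, `tool_plain`, and `output_type` behave as they do in the project's own tests):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

# Structured output means the final answer is delivered through an output tool.
agent = Agent(TestModel(), output_type=int)

@agent.tool_plain
def roll_die() -> int:
    """A deterministic stand-in for a real tool."""
    return 4

result = agent.run_sync('Roll the die')
# Only the successful `roll_die` execution is recorded;
# the output tool call used for the structured result is not.
print(result.usage().tool_calls)  # expected: 1
```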

File tree: 6 files changed (+31 / -6 lines)

docs/tools.md

Lines changed: 1 addition & 1 deletion

@@ -730,7 +730,7 @@ When a model returns multiple tool calls in one response, Pydantic AI schedules
 Async functions are run on the event loop, while sync functions are offloaded to threads. To get the best performance, _always_ use an async function _unless_ you're doing blocking I/O (and there's no way to use a non-blocking library instead) or CPU-bound work (like `numpy` or `scikit-learn` operations), so that simple functions are not offloaded to threads unnecessarily.

 !!! note "Limiting exact tool executions"
-    You can cap the exact number of tool executions within a run using `UsageLimits(tool_calls_limit=...)`. The counter increments immediately before each actual tool invocation (after successful argument validation), and concurrent calls are counted safely.
+    You can cap the exact number of tool executions within a run using `UsageLimits(tool_calls_limit=...)`. The counter increments after each successful tool invocation. Note that output tools (used for structured output) are not counted in the `tool_calls` metric.

 ## Third-Party Tools

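To make the documented cap concrete, here is a minimal sketch (assuming `UsageLimits` is importable from `pydantic_ai.usage` and that `Agent.run_sync` accepts a `usage_limits` argument, as in the project's tests; the tool and prompt are made up):

```python
from pydantic_ai import Agent, UsageLimitExceeded
from pydantic_ai.models.test import TestModel
from pydantic_ai.usage import UsageLimits

agent = Agent(TestModel())

@agent.tool_plain
def lookup_weather(city: str) -> str:
    return f'Weather in {city}: sunny'

try:
    result = agent.run_sync(
        'What is the weather in Paris?',
        usage_limits=UsageLimits(tool_calls_limit=1),
    )
    # At most one successful tool execution is recorded against the limit.
    print(result.usage().tool_calls)
except UsageLimitExceeded as exc:
    # Raised by the pre-call check once the recorded count has reached the limit.
    print(exc)
```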

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 5 additions & 1 deletion

@@ -122,9 +122,13 @@ async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validat

             if tool.tool_def.kind != 'output' and self.ctx.usage_limits is not None:
                 self.ctx.usage_limits.check_before_tool_call(self.ctx.usage)
+
+            result = await self.toolset.call_tool(name, args_dict, ctx, tool)
+
+            if tool.tool_def.kind != 'output':
                 self.ctx.usage.tool_calls += 1

-            return await self.toolset.call_tool(name, args_dict, ctx, tool)
+            return result
         except (ValidationError, ModelRetry) as e:
             max_retries = tool.max_retries if tool is not None else 1
             current_retry = self.ctx.retries.get(name, 0)

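Worth noting about the ordering above: `check_before_tool_call` still runs before the tool is invoked, so `tool_calls_limit` is enforced against the successes recorded so far, while the counter itself is only incremented once `call_tool` has returned. A tool that raises `ValidationError` or `ModelRetry` therefore jumps straight to the `except` branch, and the failed attempt never reaches `self.ctx.usage.tool_calls += 1`.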

tests/models/test_cohere.py

Lines changed: 1 addition & 1 deletion

@@ -331,7 +331,7 @@ async def get_location(loc_name: str) -> str:
             input_tokens=5,
             output_tokens=3,
             details={'input_tokens': 4, 'output_tokens': 2},
-            tool_calls=2,
+            tool_calls=1,
         )
     )

tests/models/test_gemini.py

Lines changed: 1 addition & 1 deletion

@@ -783,7 +783,7 @@ async def get_location(loc_name: str) -> str:
             ),
         ]
     )
-    assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6, tool_calls=3))
+    assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6, tool_calls=2))


 async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None):

tests/models/test_openai.py

Lines changed: 1 addition & 1 deletion

@@ -423,7 +423,7 @@ async def get_location(loc_name: str) -> str:
         ]
     )
     assert result.usage() == snapshot(
-        RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, tool_calls=2)
+        RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, tool_calls=1)
     )

tests/test_usage_limits.py

Lines changed: 22 additions & 1 deletion

@@ -11,6 +11,7 @@
 from pydantic import BaseModel

 from pydantic_ai import Agent, RunContext, UsageLimitExceeded
+from pydantic_ai.exceptions import ModelRetry
 from pydantic_ai.messages import ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart, UserPromptPart
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import ToolOutput

@@ -163,7 +164,7 @@ async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int:
     async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int:
         delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
         delegate_usage = delegate_result.usage()
-        assert delegate_usage == snapshot(RunUsage(requests=2, input_tokens=102, output_tokens=9, tool_calls=1))
+        assert delegate_usage == snapshot(RunUsage(requests=2, input_tokens=102, output_tokens=9))
         return delegate_result.output

     result2 = await controller_agent2.run('foobar')

@@ -287,3 +288,23 @@ async def another_regular_tool(x: str) -> str:
     result_output = await test_agent_with_output.run('test')

     assert result_output.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=15, tool_calls=1))
+
+
+async def test_failed_tool_calls_not_counted() -> None:
+    """Test that failed tool calls (raising ModelRetry) are not counted."""
+    test_agent = Agent(TestModel())
+
+    call_count = 0
+
+    @test_agent.tool_plain
+    async def flaky_tool(x: str) -> str:
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            raise ModelRetry('Temporary failure, please retry')
+        return f'{x}-success'
+
+    result = await test_agent.run('test')
+    # The tool was called twice (1 failure + 1 success), but only the successful call should be counted
+    assert call_count == 2
+    assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=176, output_tokens=29, tool_calls=1))

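The expected `requests=3` in the new test's snapshot follows from the retry flow: the first model request produces the tool call that fails with `ModelRetry`, the second produces the call that succeeds, and a third request returns the final response; the tool runs twice, but only the successful run is counted.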
