Skip to content

Commit f1cd425

Browse files
committed
trace the output tool call but not count it
1 parent 53962e5 commit f1cd425

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,7 @@ async def handle_call(
107107
if self.tools is None or self.ctx is None:
108108
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
109109

110-
if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
111-
# Output tool calls are not traced and not counted
112-
return await self._call_tool(
113-
call,
114-
allow_partial=allow_partial,
115-
wrap_validation_errors=wrap_validation_errors,
116-
approved=approved,
117-
)
118-
else:
119-
return await self._call_function_tool(
110+
return await self._call_function_tool(
120111
call,
121112
allow_partial=allow_partial,
122113
wrap_validation_errors=wrap_validation_errors,
@@ -217,12 +208,19 @@ async def _call_function_tool(
217208
"""See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
218209
instrumentation_names = InstrumentationNames.for_version(instrumentation_version)
219210

211+
if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
212+
tool_name: str = 'Output Tool'
213+
output_tool_flag = True
214+
else:
215+
tool_name: str = call.tool_name
216+
output_tool_flag = False
217+
220218
span_attributes = {
221-
'gen_ai.tool.name': call.tool_name,
219+
'gen_ai.tool.name': tool_name,
222220
# NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
223221
'gen_ai.tool.call.id': call.tool_call_id,
224222
**({instrumentation_names.tool_arguments_attr: call.args_as_json_str()} if include_content else {}),
225-
'logfire.msg': f'running tool: {call.tool_name}',
223+
'logfire.msg': f'running tool: {tool_name}',
226224
# add the JSON schema so these attributes are formatted nicely in Logfire
227225
'logfire.json_schema': json.dumps(
228226
{
@@ -243,7 +241,7 @@ async def _call_function_tool(
243241
),
244242
}
245243
with tracer.start_as_current_span(
246-
instrumentation_names.get_tool_span_name(call.tool_name),
244+
instrumentation_names.get_tool_span_name(tool_name),
247245
attributes=span_attributes,
248246
) as span:
249247
try:
@@ -253,7 +251,9 @@ async def _call_function_tool(
253251
wrap_validation_errors=wrap_validation_errors,
254252
approved=approved,
255253
)
256-
usage.tool_calls += 1
254+
if not output_tool_flag:
255+
# Output tool calls are not counted
256+
usage.tool_calls += 1
257257

258258
except ToolRetryError as e:
259259
part = e.tool_retry

0 commit comments

Comments
 (0)