
Commit f9f1c03

feat: Add output function tracing (#2191)
Co-authored-by: Alex Hall <[email protected]>
1 parent c94cc03 commit f9f1c03

5 files changed: +852 -15 lines changed


pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 5 additions & 2 deletions
@@ -341,6 +341,7 @@ async def stream(
                 ctx.deps.output_schema,
                 ctx.deps.output_validators,
                 build_run_context(ctx),
+                _output.build_trace_context(ctx),
                 ctx.deps.usage_limits,
             )
             yield agent_stream
@@ -529,7 +530,8 @@ async def _handle_tool_calls(
     if isinstance(output_schema, _output.ToolOutputSchema):
         for call, output_tool in output_schema.find_tool(tool_calls):
             try:
-                result_data = await output_tool.process(call, run_context)
+                trace_context = _output.build_trace_context(ctx)
+                result_data = await output_tool.process(call, run_context, trace_context)
                 result_data = await _validate_output(result_data, ctx, call)
             except _output.ToolRetryError as e:
                 # TODO: Should only increment retry stuff once per node execution, not for each tool call
@@ -586,7 +588,8 @@ async def _handle_text_response(
     try:
         if isinstance(output_schema, _output.TextOutputSchema):
             run_context = build_run_context(ctx)
-            result_data = await output_schema.process(text, run_context)
+            trace_context = _output.build_trace_context(ctx)
+            result_data = await output_schema.process(text, run_context, trace_context)
         else:
             m = _messages.RetryPromptPart(
                 content='Plain text responses are not permitted, please include your response in a tool call',

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 115 additions & 9 deletions
@@ -1,16 +1,20 @@
 from __future__ import annotations as _annotations
 
+import dataclasses
 import inspect
 import json
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Iterable, Iterator, Sequence
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
 
+from opentelemetry.trace import Tracer
 from pydantic import TypeAdapter, ValidationError
 from pydantic_core import SchemaValidator
 from typing_extensions import TypedDict, TypeVar, assert_never
 
+from pydantic_graph.nodes import GraphRunContext
+
 from . import _function_schema, _utils, messages as _messages
 from ._run_context import AgentDepsT, RunContext
 from .exceptions import ModelRetry, UserError
@@ -29,6 +33,8 @@
 from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
 
 if TYPE_CHECKING:
+    from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState
+
     from .profiles import ModelProfile
 
 T = TypeVar('T')
@@ -66,6 +72,71 @@
 DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
 
 
+@dataclass(frozen=True)
+class TraceContext:
+    """A context for tracing output processing."""
+
+    tracer: Tracer
+    include_content: bool
+    call: _messages.ToolCallPart | None = None
+
+    def with_call(self, call: _messages.ToolCallPart):
+        return dataclasses.replace(self, call=call)
+
+    async def execute_function_with_span(
+        self,
+        function_schema: _function_schema.FunctionSchema,
+        run_context: RunContext[AgentDepsT],
+        args: dict[str, Any] | Any,
+        call: _messages.ToolCallPart,
+        include_tool_call_id: bool = True,
+    ) -> Any:
+        """Execute a function call within a traced span, automatically recording the response."""
+        # Set up span attributes
+        attributes = {
+            'gen_ai.tool.name': call.tool_name,
+            'logfire.msg': f'running output function: {call.tool_name}',
+        }
+        if include_tool_call_id:
+            attributes['gen_ai.tool.call.id'] = call.tool_call_id
+        if self.include_content:
+            attributes['tool_arguments'] = call.args_as_json_str()
+            attributes['logfire.json_schema'] = json.dumps(
+                {
+                    'type': 'object',
+                    'properties': {
+                        'tool_arguments': {'type': 'object'},
+                        'tool_response': {'type': 'object'},
+                    },
+                }
+            )
+
+        # Execute function within span
+        with self.tracer.start_as_current_span('running output function', attributes=attributes) as span:
+            output = await function_schema.call(args, run_context)
+
+            # Record response if content inclusion is enabled
+            if self.include_content and span.is_recording():
+                from .models.instrumented import InstrumentedModel
+
+                span.set_attribute(
+                    'tool_response',
+                    output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
+                )
+
+            return output
+
+
+def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext:
+    """Build a `TraceContext` from the current agent graph run context."""
+    return TraceContext(
+        tracer=ctx.deps.tracer,
+        include_content=(
+            ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
+        ),
+    )
+
+
 class ToolRetryError(Exception):
     """Exception used to signal a `ToolRetry` message should be returned to the LLM."""
 
@@ -96,6 +167,7 @@ async def validate(
             result: The result data after Pydantic validation the message content.
             tool_call: The original tool call message, `None` if there was no tool call.
             run_context: The current run context.
+            trace_context: The trace context to use for tracing the output processing.
 
         Returns:
             Result of either the validated result data (ok) or a retry message (Err).
@@ -349,6 +421,7 @@ async def process(
         self,
         text: str,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -371,6 +444,7 @@ async def process(
         self,
         text: str,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -379,6 +453,7 @@ async def process(
         Args:
             text: The output text to validate.
             run_context: The current run context.
+            trace_context: The trace context to use for tracing the output processing.
             allow_partial: If true, allow partial validation.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
@@ -389,7 +464,7 @@ async def process(
             return cast(OutputDataT, text)
 
         return await self.processor.process(
-            text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
         )
 
 
@@ -417,6 +492,7 @@ async def process(
         self,
         text: str,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -425,14 +501,15 @@ async def process(
         Args:
             text: The output text to validate.
             run_context: The current run context.
+            trace_context: The trace context to use for tracing the output processing.
             allow_partial: If true, allow partial validation.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
         Returns:
             Either the validated output data (left) or a retry message (right).
         """
         return await self.processor.process(
-            text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
         )
 
 
@@ -468,6 +545,7 @@ async def process(
         self,
         text: str,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -476,6 +554,7 @@ async def process(
         Args:
             text: The output text to validate.
             run_context: The current run context.
+            trace_context: The trace context to use for tracing the output processing.
             allow_partial: If true, allow partial validation.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
@@ -485,7 +564,7 @@ async def process(
         text = _utils.strip_markdown_fences(text)
 
         return await self.processor.process(
-            text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
         )
 
 
@@ -568,6 +647,7 @@ async def process(
         self,
         data: str,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -637,6 +717,7 @@ async def process(
         self,
         data: str | dict[str, Any] | None,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -645,6 +726,7 @@ async def process(
         Args:
             data: The output data to validate.
             run_context: The current run context.
+            trace_context: The trace context to use for tracing the output processing.
             allow_partial: If true, allow partial validation.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
@@ -670,8 +752,18 @@ async def process(
                 output = output[k]
 
         if self._function_schema:
+            # Wraps the output function call in an OpenTelemetry span.
+            if trace_context.call:
+                call = trace_context.call
+                include_tool_call_id = True
+            else:
+                function_name = getattr(self._function_schema.function, '__name__', 'output_function')
+                call = _messages.ToolCallPart(tool_name=function_name, args=data)
+                include_tool_call_id = False
             try:
-                output = await self._function_schema.call(output, run_context)
+                output = await trace_context.execute_function_with_span(
+                    self._function_schema, run_context, output, call, include_tool_call_id
+                )
             except ModelRetry as r:
                 if wrap_validation_errors:
                     m = _messages.RetryPromptPart(
@@ -784,11 +876,12 @@ async def process(
         self,
         data: str | dict[str, Any] | None,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         union_object = await self._union_processor.process(
-            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
         )
 
         result = union_object.result
@@ -804,7 +897,7 @@ async def process(
                 raise
 
         return await processor.process(
-            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
         )
 
 
@@ -835,13 +928,20 @@ async def process(
         self,
         data: str,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
-
+        # Wraps the output function call in an OpenTelemetry span.
+        # Note: PlainTextOutputProcessor is used for text responses (not tool calls),
+        # so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id
+        function_name = getattr(self._function_schema.function, '__name__', 'text_output_function')
+        call = _messages.ToolCallPart(tool_name=function_name, args=args)
         try:
-            output = await self._function_schema.call(args, run_context)
+            output = await trace_context.execute_function_with_span(
+                self._function_schema, run_context, args, call, include_tool_call_id=False
+            )
         except ModelRetry as r:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -881,6 +981,7 @@ async def process(
         self,
         tool_call: _messages.ToolCallPart,
         run_context: RunContext[AgentDepsT],
+        trace_context: TraceContext,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -889,6 +990,7 @@ async def process(
         Args:
             tool_call: The tool call from the LLM to validate.
             run_context: The current run context.
+            trace_context: The trace context to use for tracing the output processing.
             allow_partial: If true, allow partial validation.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
@@ -897,7 +999,11 @@ async def process(
         """
         try:
             output = await self.processor.process(
-                tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False
+                tool_call.args,
+                run_context,
+                trace_context.with_call(tool_call),
+                allow_partial=allow_partial,
+                wrap_validation_errors=False,
            )
         except ValidationError as e:
             if wrap_validation_errors:
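The core of this file's change is the span that `TraceContext.execute_function_with_span` opens around each output function call. Below is a minimal standalone sketch of the same pattern, not part of the commit: the `summarize` function, the tracer setup, and `run_output_function` are illustrative stand-ins, and the spans are no-ops unless an OpenTelemetry SDK/exporter is configured.

import asyncio
import json

from opentelemetry import trace

tracer = trace.get_tracer(__name__)


async def summarize(text: str) -> dict:
    # Stand-in for an output function that the model's final answer is passed to.
    return {'summary': text[:20]}


async def run_output_function(func, args: dict, include_content: bool = True):
    # Mirror the attribute layout used by execute_function_with_span in the diff above.
    attributes = {
        'gen_ai.tool.name': func.__name__,
        'logfire.msg': f'running output function: {func.__name__}',
    }
    if include_content:
        attributes['tool_arguments'] = json.dumps(args)

    with tracer.start_as_current_span('running output function', attributes=attributes) as span:
        output = await func(**args)
        # Record the response on the span only when content capture is enabled.
        if include_content and span.is_recording():
            span.set_attribute('tool_response', json.dumps(output))
        return output


print(asyncio.run(run_output_function(summarize, {'text': 'The weather in Paris is sunny today.'})))

Recording the response only when both `span.is_recording()` and content capture are true mirrors the `include_content` gate that `build_trace_context` derives from the run's instrumentation settings.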

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 0 deletions
@@ -1089,6 +1089,7 @@ async def on_complete() -> None:
                     streamed_response,
                     graph_ctx.deps.output_schema,
                     _agent_graph.build_run_context(graph_ctx),
+                    _output.build_trace_context(graph_ctx),
                     graph_ctx.deps.output_validators,
                     final_result_details.tool_name,
                     on_complete,
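From the user side, the effect of this commit is that output function calls now show up as "running output function" spans when instrumentation is enabled. A rough usage sketch, assuming pydantic-ai's `Agent(..., instrument=True)` and function output types as documented around this release; the model name and the `extract_city` function are placeholders, and an OpenTelemetry exporter (for example via Logfire) has to be configured separately for the spans to be exported.

from pydantic_ai import Agent


def extract_city(answer: str) -> str:
    # Output function: the model's final response is passed to this function, and with
    # instrumentation enabled the call is wrapped in a 'running output function' span.
    return answer.split()[-1].rstrip('.')


# instrument=True turns on OpenTelemetry instrumentation for this agent's runs.
agent = Agent('openai:gpt-4o', output_type=extract_city, instrument=True)

result = agent.run_sync('What is the capital of France? Answer in one short sentence.')
print(result.output)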

0 commit comments
