Skip to content

Commit ea2bbc5

Browse files
authored
Record agent run attributes in case of streaming and exception (#1610)
1 parent 0d74e2c commit ea2bbc5

File tree

4 files changed

+133
-44
lines changed

4 files changed

+133
-44
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
import asyncio
44
import dataclasses
5-
import json
65
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
76
from contextlib import asynccontextmanager, contextmanager
87
from contextvars import ContextVar
98
from dataclasses import field
109
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
1110

12-
from opentelemetry.trace import Span, Tracer
11+
from opentelemetry.trace import Tracer
1312
from typing_extensions import TypeGuard, TypeVar, assert_never
1413

1514
from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -24,7 +23,6 @@
2423
result,
2524
usage as _usage,
2625
)
27-
from .models.instrumented import InstrumentedModel
2826
from .result import OutputDataT, ToolOutput
2927
from .settings import ModelSettings, merge_model_settings
3028
from .tools import RunContext, Tool, ToolDefinition
@@ -95,7 +93,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
9593
function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
9694
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
9795

98-
run_span: Span
9996
tracer: Tracer
10097

10198

@@ -498,39 +495,12 @@ def _handle_final_result(
498495
final_result: result.FinalResult[NodeRunEndT],
499496
tool_responses: list[_messages.ModelRequestPart],
500497
) -> End[result.FinalResult[NodeRunEndT]]:
501-
run_span = ctx.deps.run_span
502-
usage = ctx.state.usage
503498
messages = ctx.state.message_history
504499

505500
# For backwards compatibility, append a new ModelRequest using the tool returns and retries
506501
if tool_responses:
507502
messages.append(_messages.ModelRequest(parts=tool_responses))
508503

509-
run_span.set_attributes(
510-
{
511-
**usage.opentelemetry_attributes(),
512-
'all_messages_events': json.dumps(
513-
[InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
514-
),
515-
'final_result': final_result.output
516-
if isinstance(final_result.output, str)
517-
else json.dumps(InstrumentedModel.serialize_any(final_result.output)),
518-
}
519-
)
520-
run_span.set_attributes(
521-
{
522-
'logfire.json_schema': json.dumps(
523-
{
524-
'type': 'object',
525-
'properties': {
526-
'all_messages_events': {'type': 'array'},
527-
'final_result': {'type': 'object'},
528-
},
529-
}
530-
),
531-
}
532-
)
533-
534504
return End(final_result)
535505

536506
async def _handle_text_response(

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
import inspect
5+
import json
56
import warnings
67
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
78
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
@@ -600,9 +601,10 @@ async def main():
600601
)
601602

602603
# Build the initial state
604+
usage = usage or _usage.Usage()
603605
state = _agent_graph.GraphAgentState(
604606
message_history=message_history[:] if message_history else [],
605-
usage=usage or _usage.Usage(),
607+
usage=usage,
606608
retries=0,
607609
run_step=0,
608610
)
@@ -656,7 +658,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
656658
output_validators=output_validators,
657659
function_tools=self._function_tools,
658660
mcp_servers=self._mcp_servers,
659-
run_span=run_span,
660661
tracer=tracer,
661662
get_instructions=get_instructions,
662663
)
@@ -669,14 +670,51 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
669670
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
670671
)
671672

672-
async with graph.iter(
673-
start_node,
674-
state=state,
675-
deps=graph_deps,
676-
span=use_span(run_span, end_on_exit=True) if run_span.is_recording() else None,
677-
infer_name=False,
678-
) as graph_run:
679-
yield AgentRun(graph_run)
673+
try:
674+
async with graph.iter(
675+
start_node,
676+
state=state,
677+
deps=graph_deps,
678+
span=use_span(run_span) if run_span.is_recording() else None,
679+
infer_name=False,
680+
) as graph_run:
681+
agent_run = AgentRun(graph_run)
682+
yield agent_run
683+
if (final_result := agent_run.result) is not None and run_span.is_recording():
684+
run_span.set_attribute(
685+
'final_result',
686+
(
687+
final_result.output
688+
if isinstance(final_result.output, str)
689+
else json.dumps(InstrumentedModel.serialize_any(final_result.output))
690+
),
691+
)
692+
finally:
693+
try:
694+
if run_span.is_recording():
695+
run_span.set_attributes(self._run_span_end_attributes(state, usage))
696+
finally:
697+
run_span.end()
698+
699+
def _run_span_end_attributes(self, state: _agent_graph.GraphAgentState, usage: _usage.Usage):
700+
return {
701+
**usage.opentelemetry_attributes(),
702+
'all_messages_events': json.dumps(
703+
[
704+
InstrumentedModel.event_to_dict(e)
705+
for e in InstrumentedModel.messages_to_otel_events(state.message_history)
706+
]
707+
),
708+
'logfire.json_schema': json.dumps(
709+
{
710+
'type': 'object',
711+
'properties': {
712+
'all_messages_events': {'type': 'array'},
713+
'final_result': {'type': 'object'},
714+
},
715+
}
716+
),
717+
}
680718

681719
@overload
682720
def run_sync(

tests/models/test_fallback.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None
216216
'agent_name': 'agent',
217217
'logfire.msg': 'agent run',
218218
'logfire.span_type': 'span',
219+
'gen_ai.usage.input_tokens': 50,
220+
'gen_ai.usage.output_tokens': 2,
221+
'all_messages_events': '[{"content": "input", "role": "user", "gen_ai.message.index": 0, "event.name": "gen_ai.user.message"}, {"role": "assistant", "content": "hello world", "gen_ai.message.index": 1, "event.name": "gen_ai.assistant.message"}]',
222+
'logfire.json_schema': '{"type": "object", "properties": {"all_messages_events": {"type": "array"}, "final_result": {"type": "object"}}}',
219223
},
220224
},
221225
]
@@ -236,6 +240,82 @@ def test_all_failed() -> None:
236240
assert exceptions[0].body == {'error': 'test error'}
237241

238242

243+
@pytest.mark.skipif(not logfire_imports_successful(), reason='logfire not installed')
244+
def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
245+
fallback_model = FallbackModel(failure_model, failure_model)
246+
agent = Agent(model=fallback_model, instrument=True)
247+
with pytest.raises(ExceptionGroup) as exc_info:
248+
agent.run_sync('hello')
249+
assert 'All models from FallbackModel failed' in exc_info.value.args[0]
250+
exceptions = exc_info.value.exceptions
251+
assert len(exceptions) == 2
252+
assert isinstance(exceptions[0], ModelHTTPError)
253+
assert exceptions[0].status_code == 500
254+
assert exceptions[0].model_name == 'test-function-model'
255+
assert exceptions[0].body == {'error': 'test error'}
256+
assert capfire.exporter.exported_spans_as_dict() == snapshot(
257+
[
258+
{
259+
'name': 'chat fallback:function:failure_response:,function:failure_response:',
260+
'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
261+
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
262+
'start_time': 2000000000,
263+
'end_time': 4000000000,
264+
'attributes': {
265+
'gen_ai.operation.name': 'chat',
266+
'gen_ai.system': 'fallback:function,function',
267+
'gen_ai.request.model': 'fallback:function:failure_response:,function:failure_response:',
268+
'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}',
269+
'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}',
270+
'logfire.span_type': 'span',
271+
'logfire.msg': 'chat fallback:function:failure_response:,function:failure_response:',
272+
'logfire.level_num': 17,
273+
},
274+
'events': [
275+
{
276+
'name': 'exception',
277+
'timestamp': 3000000000,
278+
'attributes': {
279+
'exception.type': 'pydantic_ai.exceptions.FallbackExceptionGroup',
280+
'exception.message': 'All models from FallbackModel failed (2 sub-exceptions)',
281+
'exception.stacktrace': '+------------------------------------',
282+
'exception.escaped': 'False',
283+
},
284+
}
285+
],
286+
},
287+
{
288+
'name': 'agent run',
289+
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
290+
'parent': None,
291+
'start_time': 1000000000,
292+
'end_time': 6000000000,
293+
'attributes': {
294+
'model_name': 'fallback:function:failure_response:,function:failure_response:',
295+
'agent_name': 'agent',
296+
'logfire.msg': 'agent run',
297+
'logfire.span_type': 'span',
298+
'all_messages_events': '[{"content": "hello", "role": "user", "gen_ai.message.index": 0, "event.name": "gen_ai.user.message"}]',
299+
'logfire.json_schema': '{"type": "object", "properties": {"all_messages_events": {"type": "array"}, "final_result": {"type": "object"}}}',
300+
'logfire.level_num': 17,
301+
},
302+
'events': [
303+
{
304+
'name': 'exception',
305+
'timestamp': 5000000000,
306+
'attributes': {
307+
'exception.type': 'pydantic_ai.exceptions.FallbackExceptionGroup',
308+
'exception.message': 'All models from FallbackModel failed (2 sub-exceptions)',
309+
'exception.stacktrace': '+------------------------------------',
310+
'exception.escaped': 'False',
311+
},
312+
}
313+
],
314+
},
315+
]
316+
)
317+
318+
239319
async def success_response_stream(_model_messages: list[ModelMessage], _agent_info: AgentInfo) -> AsyncIterator[str]:
240320
yield 'hello '
241321
yield 'world'

uv.lock

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)