Skip to content

Commit 33e2bf8

Browse files
authored
Include final_result agent span attribute after streaming (#3170)
1 parent 842b7a8 commit 33e2bf8

File tree

4 files changed

+192
-0
lines changed

4 files changed

+192
-0
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,18 @@ def _handle_final_result(
710710
__repr__ = dataclasses_no_defaults_repr
711711

712712

713+
@dataclasses.dataclass
714+
class SetFinalResult(AgentNode[DepsT, NodeRunEndT]):
715+
"""A node that immediately ends the graph run after a streaming response produced a final result."""
716+
717+
final_result: result.FinalResult[NodeRunEndT]
718+
719+
async def run(
720+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
721+
) -> End[result.FinalResult[NodeRunEndT]]:
722+
return End(self.final_result)
723+
724+
713725
def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
714726
"""Build a `RunContext` object from the current agent graph run context."""
715727
return RunContext[DepsT](
@@ -1123,6 +1135,7 @@ def build_agent_graph(
11231135
UserPromptNode[DepsT],
11241136
ModelRequestNode[DepsT],
11251137
CallToolsNode[DepsT],
1138+
SetFinalResult[DepsT],
11261139
)
11271140
graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
11281141
nodes=nodes,

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,14 @@ async def on_complete() -> None:
524524
await stream.get_output(), final_result.tool_name, final_result.tool_call_id
525525
)
526526

527+
# When we get here, the `ModelRequestNode` has completed streaming after the final result was found.
528+
# When running an agent with `agent.run`, we'd then move to `CallToolsNode` to execute the tool calls and
529+
# find the final result.
530+
# We also want to execute tool calls (in case `agent.end_strategy == 'exhaustive'`) here, but
531+
# we don't want to run the `CallToolsNode` logic to determine the final output, as it would be
532+
# wasteful and could produce a different result (e.g. when text output is followed by tool calls).
533+
# So we call `process_tool_calls` directly and then end the run with the found final result.
534+
527535
parts: list[_messages.ModelRequestPart] = []
528536
async for _event in _agent_graph.process_tool_calls(
529537
tool_manager=graph_ctx.deps.tool_manager,
@@ -534,9 +542,13 @@ async def on_complete() -> None:
534542
output_parts=parts,
535543
):
536544
pass
545+
546+
# For backwards compatibility, append a new ModelRequest using the tool returns and retries
537547
if parts:
538548
messages.append(_messages.ModelRequest(parts))
539549

550+
await agent_run.next(_agent_graph.SetFinalResult(final_result))
551+
540552
yield StreamedRunResult(
541553
messages,
542554
graph_ctx.deps.new_message_index,

tests/models/test_fallback.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None
274274
'gen_ai.agent.name': 'agent',
275275
'logfire.msg': 'agent run',
276276
'logfire.span_type': 'span',
277+
'final_result': 'hello world',
277278
'gen_ai.usage.input_tokens': 50,
278279
'gen_ai.usage.output_tokens': 2,
279280
'pydantic_ai.all_messages': [

tests/test_logfire.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2890,3 +2890,169 @@ def instructions(ctx: RunContext[None]):
28902890
]
28912891
)
28922892
)
2893+
2894+
2895+
@pytest.mark.skipif(not logfire_installed, reason='logfire not installed')
2896+
@pytest.mark.parametrize(
2897+
'instrument',
2898+
[InstrumentationSettings(version=1), InstrumentationSettings(version=2), InstrumentationSettings(version=3)],
2899+
)
2900+
async def test_run_stream(
2901+
get_logfire_summary: Callable[[], LogfireSummary], instrument: InstrumentationSettings
2902+
) -> None:
2903+
my_agent = Agent(model=TestModel(), instrument=instrument)
2904+
2905+
@my_agent.instructions
2906+
def instructions(ctx: RunContext[None]):
2907+
return 'Instructions for the current agent run'
2908+
2909+
async with my_agent.run_stream('Hello') as stream:
2910+
async for _ in stream.stream_output():
2911+
pass
2912+
2913+
summary = get_logfire_summary()
2914+
chat_span_attributes = summary.attributes[1]
2915+
if instrument.version == 1:
2916+
assert summary.attributes[0] == snapshot(
2917+
{
2918+
'model_name': 'test',
2919+
'agent_name': 'my_agent',
2920+
'gen_ai.agent.name': 'my_agent',
2921+
'logfire.msg': 'my_agent run',
2922+
'logfire.span_type': 'span',
2923+
'final_result': 'success (no tool calls)',
2924+
'gen_ai.usage.input_tokens': 51,
2925+
'gen_ai.usage.output_tokens': 4,
2926+
'all_messages_events': IsJson(
2927+
snapshot(
2928+
[
2929+
{
2930+
'content': 'Instructions for the current agent run',
2931+
'role': 'system',
2932+
'event.name': 'gen_ai.system.message',
2933+
},
2934+
{
2935+
'content': 'Hello',
2936+
'role': 'user',
2937+
'gen_ai.message.index': 0,
2938+
'event.name': 'gen_ai.user.message',
2939+
},
2940+
{
2941+
'role': 'assistant',
2942+
'content': 'success (no tool calls)',
2943+
'gen_ai.message.index': 1,
2944+
'event.name': 'gen_ai.assistant.message',
2945+
},
2946+
]
2947+
)
2948+
),
2949+
'logfire.json_schema': IsJson(
2950+
snapshot(
2951+
{
2952+
'type': 'object',
2953+
'properties': {
2954+
'all_messages_events': {'type': 'array'},
2955+
'final_result': {'type': 'object'},
2956+
},
2957+
}
2958+
)
2959+
),
2960+
}
2961+
)
2962+
2963+
assert chat_span_attributes['events'] == IsJson(
2964+
snapshot(
2965+
[
2966+
{
2967+
'content': 'Instructions for the current agent run',
2968+
'role': 'system',
2969+
'gen_ai.system': 'test',
2970+
'event.name': 'gen_ai.system.message',
2971+
},
2972+
{
2973+
'content': 'Hello',
2974+
'role': 'user',
2975+
'gen_ai.system': 'test',
2976+
'gen_ai.message.index': 0,
2977+
'event.name': 'gen_ai.user.message',
2978+
},
2979+
{
2980+
'index': 0,
2981+
'message': {'role': 'assistant', 'content': 'success (no tool calls)'},
2982+
'gen_ai.system': 'test',
2983+
'event.name': 'gen_ai.choice',
2984+
},
2985+
]
2986+
)
2987+
)
2988+
else:
2989+
if instrument.version == 2:
2990+
assert summary.traces == snapshot(
2991+
[
2992+
{
2993+
'id': 0,
2994+
'name': 'agent run',
2995+
'message': 'my_agent run',
2996+
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
2997+
}
2998+
]
2999+
)
3000+
else:
3001+
assert summary.traces == snapshot(
3002+
[
3003+
{
3004+
'id': 0,
3005+
'name': 'invoke_agent my_agent',
3006+
'message': 'my_agent run',
3007+
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
3008+
}
3009+
]
3010+
)
3011+
3012+
assert summary.attributes[0] == snapshot(
3013+
{
3014+
'model_name': 'test',
3015+
'agent_name': 'my_agent',
3016+
'gen_ai.agent.name': 'my_agent',
3017+
'logfire.msg': 'my_agent run',
3018+
'logfire.span_type': 'span',
3019+
'final_result': 'success (no tool calls)',
3020+
'gen_ai.usage.input_tokens': 51,
3021+
'gen_ai.usage.output_tokens': 4,
3022+
'pydantic_ai.all_messages': IsJson(
3023+
snapshot(
3024+
[
3025+
{'role': 'user', 'parts': [{'type': 'text', 'content': 'Hello'}]},
3026+
{'role': 'assistant', 'parts': [{'type': 'text', 'content': 'success (no tool calls)'}]},
3027+
]
3028+
)
3029+
),
3030+
'gen_ai.system_instructions': '[{"type": "text", "content": "Instructions for the current agent run"}]',
3031+
'logfire.json_schema': IsJson(
3032+
snapshot(
3033+
{
3034+
'type': 'object',
3035+
'properties': {
3036+
'pydantic_ai.all_messages': {'type': 'array'},
3037+
'gen_ai.system_instructions': {'type': 'array'},
3038+
'final_result': {'type': 'object'},
3039+
},
3040+
}
3041+
)
3042+
),
3043+
}
3044+
)
3045+
3046+
assert chat_span_attributes['gen_ai.input.messages'] == IsJson(
3047+
snapshot([{'role': 'user', 'parts': [{'type': 'text', 'content': 'Hello'}]}])
3048+
)
3049+
assert chat_span_attributes['gen_ai.output.messages'] == IsJson(
3050+
snapshot(
3051+
[
3052+
{
3053+
'role': 'assistant',
3054+
'parts': [{'type': 'text', 'content': 'success (no tool calls)'}],
3055+
}
3056+
]
3057+
)
3058+
)

0 commit comments

Comments
 (0)