Skip to content

Commit 79cfbfd

Browse files
DouweMKRRT7
authored andcommitted
Reduce duplication between StreamedRunResult and AgentStream (pydantic#2275)
1 parent 2d7c850 commit 79cfbfd

File tree

2 files changed

+124
-180
lines changed

2 files changed

+124
-180
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
3737
from .output import OutputDataT, OutputSpec
3838
from .profiles import ModelProfile
39-
from .result import FinalResult, StreamedRunResult
39+
from .result import AgentStream, FinalResult, StreamedRunResult
4040
from .settings import ModelSettings, merge_model_settings
4141
from .tools import (
4242
AgentDepsT,
@@ -1127,29 +1127,15 @@ async def main():
11271127
while True:
11281128
if self.is_model_request_node(node):
11291129
graph_ctx = agent_run.ctx
1130-
async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage]
1131-
1132-
async def stream_to_final(
1133-
s: models.StreamedResponse,
1134-
) -> FinalResult[models.StreamedResponse] | None:
1135-
output_schema = graph_ctx.deps.output_schema
1136-
async for maybe_part_event in streamed_response:
1137-
if isinstance(maybe_part_event, _messages.PartStartEvent):
1138-
new_part = maybe_part_event.part
1139-
if isinstance(new_part, _messages.TextPart) and isinstance(
1140-
output_schema, _output.TextOutputSchema
1141-
):
1142-
return FinalResult(s, None, None)
1143-
elif isinstance(new_part, _messages.ToolCallPart) and (
1144-
tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
1145-
):
1146-
if tool_def.kind == 'output':
1147-
return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
1148-
elif tool_def.kind == 'deferred':
1149-
return FinalResult(s, None, None)
1130+
async with node.stream(graph_ctx) as stream:
1131+
1132+
async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None:
1133+
async for event in stream:
1134+
if isinstance(event, _messages.FinalResultEvent):
1135+
return FinalResult(s, event.tool_name, event.tool_call_id)
11501136
return None
11511137

1152-
final_result = await stream_to_final(streamed_response)
1138+
final_result = await stream_to_final(stream)
11531139
if final_result is not None:
11541140
if yielded:
11551141
raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
@@ -1184,14 +1170,8 @@ async def on_complete() -> None:
11841170
yield StreamedRunResult(
11851171
messages,
11861172
graph_ctx.deps.new_message_index,
1187-
graph_ctx.deps.usage_limits,
1188-
streamed_response,
1189-
graph_ctx.deps.output_schema,
1190-
_agent_graph.build_run_context(graph_ctx),
1191-
graph_ctx.deps.output_validators,
1192-
final_result.tool_name,
1173+
stream,
11931174
on_complete,
1194-
graph_ctx.deps.tool_manager,
11951175
)
11961176
break
11971177
next_node = await agent_run.next(node)

0 commit comments

Comments
 (0)