 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .output import OutputDataT, OutputSpec
 from .profiles import ModelProfile
-from .result import FinalResult, StreamedRunResult
+from .result import AgentStream, FinalResult, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDepsT,
@@ -1127,29 +1127,15 @@ async def main():
             while True:
                 if self.is_model_request_node(node):
                     graph_ctx = agent_run.ctx
-                    async with node._stream(graph_ctx) as streamed_response:  # pyright: ignore[reportPrivateUsage]
-
-                        async def stream_to_final(
-                            s: models.StreamedResponse,
-                        ) -> FinalResult[models.StreamedResponse] | None:
-                            output_schema = graph_ctx.deps.output_schema
-                            async for maybe_part_event in streamed_response:
-                                if isinstance(maybe_part_event, _messages.PartStartEvent):
-                                    new_part = maybe_part_event.part
-                                    if isinstance(new_part, _messages.TextPart) and isinstance(
-                                        output_schema, _output.TextOutputSchema
-                                    ):
-                                        return FinalResult(s, None, None)
-                                    elif isinstance(new_part, _messages.ToolCallPart) and (
-                                        tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
-                                    ):
-                                        if tool_def.kind == 'output':
-                                            return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
-                                        elif tool_def.kind == 'deferred':
-                                            return FinalResult(s, None, None)
+                    async with node.stream(graph_ctx) as stream:
+
+                        async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None:
+                            async for event in stream:
+                                if isinstance(event, _messages.FinalResultEvent):
+                                    return FinalResult(s, event.tool_name, event.tool_call_id)
                             return None

-                        final_result = await stream_to_final(streamed_response)
+                        final_result = await stream_to_final(stream)
                         if final_result is not None:
                             if yielded:
                                 raise exceptions.AgentRunError('Agent run produced final results')  # pragma: no cover
@@ -1184,14 +1170,8 @@ async def on_complete() -> None:
                             yield StreamedRunResult(
                                 messages,
                                 graph_ctx.deps.new_message_index,
-                                graph_ctx.deps.usage_limits,
-                                streamed_response,
-                                graph_ctx.deps.output_schema,
-                                _agent_graph.build_run_context(graph_ctx),
-                                graph_ctx.deps.output_validators,
-                                final_result.tool_name,
+                                stream,
                                 on_complete,
-                                graph_ctx.deps.tool_manager,
                             )
                             break
                 next_node = await agent_run.next(node)
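For readers skimming the diff: the rewritten `stream_to_final` no longer re-derives "is this the final output?" from `PartStartEvent`s and the output schema; it simply advances the stream until the stream itself emits a `FinalResultEvent`, leaving the remaining events un-consumed for the caller. Below is a minimal, self-contained sketch of that pattern under stated assumptions; the event and result types are illustrative stand-ins, not the library's own classes.

```python
# Sketch of the "advance until FinalResultEvent" pattern used above.
# All types here are illustrative stand-ins, not pydantic-ai's real classes.
from __future__ import annotations

import asyncio
from dataclasses import dataclass


@dataclass
class TextDeltaEvent:
    content: str


@dataclass
class FinalResultEvent:
    tool_name: str | None
    tool_call_id: str | None


@dataclass
class FinalResult:
    stream: object  # the partially consumed stream is handed back to the caller
    tool_name: str | None
    tool_call_id: str | None


async def fake_stream():
    """Pretend model response: some text, then the 'this is final' marker."""
    yield TextDeltaEvent('Hello, ')
    yield FinalResultEvent(tool_name=None, tool_call_id=None)
    yield TextDeltaEvent('world!')  # still on the stream after stream_to_final returns


async def stream_to_final(stream) -> FinalResult | None:
    # Advance the stream only until it announces a final result; everything
    # after that stays un-consumed for whoever receives the FinalResult.
    async for event in stream:
        if isinstance(event, FinalResultEvent):
            return FinalResult(stream, event.tool_name, event.tool_call_id)
    return None


async def main() -> None:
    stream = fake_stream()
    final = await stream_to_final(stream)
    print(final is not None and final.tool_name is None)  # True: plain-text final output


asyncio.run(main())
```

Because the stream now carries its own final-result signal and state, the simplified `StreamedRunResult` call in the last hunk only needs the stream and the completion callback, rather than the usage limits, output schema, validators, tool name, and tool manager individually.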