1
1
from __future__ import annotations as _annotations
2
2
3
3
from abc import ABC , abstractmethod
4
- from collections .abc import AsyncIterator , Callable
4
+ from collections .abc import AsyncIterator , Awaitable , Callable
5
5
from dataclasses import dataclass , field
6
6
from datetime import datetime
7
7
from typing import Generic , TypeVar , cast
@@ -122,7 +122,8 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
122
122
_result_schema : _result .ResultSchema [ResultData ] | None
123
123
_deps : AgentDeps
124
124
_result_validators : list [_result .ResultValidator [AgentDeps , ResultData ]]
125
- _on_complete : Callable [[list [_messages .ModelMessage ]], None ]
125
+ _result_tool_name : str | None
126
+ _on_complete : Callable [[], Awaitable [None ]]
126
127
is_complete : bool = field (default = False , init = False )
127
128
"""Whether the stream has all been received.
128
129
@@ -205,7 +206,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
205
206
combined = await self ._validate_text_result ('' .join (chunks ))
206
207
yield combined
207
208
lf_span .set_attribute ('combined_text' , combined )
208
- self ._marked_completed (_messages .ModelResponse .from_text (combined ))
209
+ await self ._marked_completed (_messages .ModelResponse .from_text (combined ))
209
210
210
211
async def stream_structured (
211
212
self , * , debounce_by : float | None = 0.1
@@ -244,7 +245,7 @@ async def stream_structured(
244
245
msg = self ._stream_response .get (final = True )
245
246
yield msg , True
246
247
lf_span .set_attribute ('structured_response' , msg )
247
- self ._marked_completed (msg )
248
+ await self ._marked_completed (msg )
248
249
249
250
async def get_data (self ) -> ResultData :
250
251
"""Stream the whole response, validate and return it."""
@@ -253,11 +254,11 @@ async def get_data(self) -> ResultData:
253
254
if isinstance (self ._stream_response , models .StreamTextResponse ):
254
255
text = '' .join (self ._stream_response .get (final = True ))
255
256
text = await self ._validate_text_result (text )
256
- self ._marked_completed (_messages .ModelResponse .from_text (text ))
257
+ await self ._marked_completed (_messages .ModelResponse .from_text (text ))
257
258
return cast (ResultData , text )
258
259
else :
259
260
message = self ._stream_response .get (final = True )
260
- self ._marked_completed (message )
261
+ await self ._marked_completed (message )
261
262
return await self .validate_structured_result (message )
262
263
263
264
@property
@@ -282,7 +283,8 @@ async def validate_structured_result(
282
283
) -> ResultData :
283
284
"""Validate a structured result message."""
284
285
assert self ._result_schema is not None , 'Expected _result_schema to not be None'
285
- match = self ._result_schema .find_tool (message )
286
+ assert self ._result_tool_name is not None , 'Expected _result_tool_name to not be None'
287
+ match = self ._result_schema .find_named_tool (message .parts , self ._result_tool_name )
286
288
if match is None :
287
289
raise exceptions .UnexpectedModelBehavior (
288
290
f'Invalid message, unable to find tool: { self ._result_schema .tool_names ()} '
@@ -306,7 +308,7 @@ async def _validate_text_result(self, text: str) -> str:
306
308
)
307
309
return text
308
310
309
- def _marked_completed (self , message : _messages .ModelResponse ) -> None :
311
+ async def _marked_completed (self , message : _messages .ModelResponse ) -> None :
310
312
self .is_complete = True
311
313
self ._all_messages .append (message )
312
- self ._on_complete (self . _all_messages )
314
+ await self ._on_complete ()
0 commit comments