55import json
66import time
77import asyncio
8- from typing import Dict , List , Union , TypeVar , Optional , Generator , AsyncGenerator , overload
8+ from typing import Dict , List , Union , TypeVar , Iterator , Optional , AsyncIterator , overload
99from datetime import datetime
1010from typing_extensions import Literal
1111
@@ -417,14 +417,15 @@ def retrieve(
417417 def stream (
418418 self ,
419419 task_id : str ,
420+ structured_output_json : None | NotGiven = NOT_GIVEN ,
420421 * ,
421422 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
422423 # The extra values given here take precedence over values defined on the client or passed to this method.
423424 extra_headers : Headers | None = None ,
424425 extra_query : Query | None = None ,
425426 extra_body : Body | None = None ,
426427 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
427- ) -> Generator [TaskView , None ]: ...
428+ ) -> Iterator [TaskView ]: ...
428429
429430 @overload
430431 def stream (
@@ -438,20 +439,20 @@ def stream(
438439 extra_query : Query | None = None ,
439440 extra_body : Body | None = None ,
440441 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
441- ) -> Generator [TaskViewWithOutput [T ], None ]: ...
442+ ) -> Iterator [TaskViewWithOutput [T ]]: ...
442443
443444 def stream (
444445 self ,
445446 task_id : str ,
446- structured_output_json : Optional [ type [BaseModel ]] | NotGiven = NOT_GIVEN ,
447+ structured_output_json : type [BaseModel ] | None | NotGiven = NOT_GIVEN ,
447448 * ,
448449 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
449450 # The extra values given here take precedence over values defined on the client or passed to this method.
450451 extra_headers : Headers | None = None ,
451452 extra_query : Query | None = None ,
452453 extra_body : Body | None = None ,
453454 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
454- ) -> Generator [Union [TaskView , TaskViewWithOutput [BaseModel ]], None ]:
455+ ) -> Iterator [Union [TaskView , TaskViewWithOutput [BaseModel ]]]:
455456 """
456457 Stream the task view as it is updated until the task is finished.
457458 """
@@ -491,7 +492,7 @@ def _watch(
491492 extra_query : Query | None = None ,
492493 extra_body : Body | None = None ,
493494 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
494- ) -> Generator [ Union [ TaskView , TaskViewWithOutput [ BaseModel ]], None ]:
495+ ) -> Iterator [ TaskView ]:
495496 """Converts a polling loop into a generator loop."""
496497 hash : str | None = None
497498
@@ -1207,86 +1208,89 @@ async def retrieve(
12071208 )
12081209
12091210 @overload
1210- async def stream (
1211+ def stream (
12111212 self ,
12121213 task_id : str ,
1214+ structured_output_json : type [T ],
12131215 * ,
12141216 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
12151217 # The extra values given here take precedence over values defined on the client or passed to this method.
12161218 extra_headers : Headers | None = None ,
12171219 extra_query : Query | None = None ,
12181220 extra_body : Body | None = None ,
12191221 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
1220- ) -> AsyncGenerator [ TaskView , None ]: ...
1222+ ) -> AsyncIterator [ TaskViewWithOutput [ T ] ]: ...
12211223
12221224 @overload
1223- async def stream (
1225+ def stream (
12241226 self ,
12251227 task_id : str ,
1226- structured_output_json : type [ T ] ,
1228+ structured_output_json : None | NotGiven = NOT_GIVEN ,
12271229 * ,
12281230 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
12291231 # The extra values given here take precedence over values defined on the client or passed to this method.
12301232 extra_headers : Headers | None = None ,
12311233 extra_query : Query | None = None ,
12321234 extra_body : Body | None = None ,
12331235 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
1234- ) -> AsyncGenerator [ TaskViewWithOutput [ T ], None ]: ...
1236+ ) -> AsyncIterator [ TaskView ]: ...
12351237
1236- async def stream (
1238+ def stream (
12371239 self ,
12381240 task_id : str ,
1239- structured_output_json : Optional [ type [BaseModel ]] | NotGiven = NOT_GIVEN ,
1241+ structured_output_json : type [T ] | None | NotGiven = NOT_GIVEN ,
12401242 * ,
12411243 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
12421244 # The extra values given here take precedence over values defined on the client or passed to this method.
12431245 extra_headers : Headers | None = None ,
12441246 extra_query : Query | None = None ,
12451247 extra_body : Body | None = None ,
12461248 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
1247- ) -> AsyncGenerator [ Union [ TaskView , TaskViewWithOutput [BaseModel ]], None ]:
1249+ ) -> AsyncIterator [ TaskView ] | AsyncIterator [ TaskViewWithOutput [T ] ]:
12481250 """
12491251 Stream the task view as it is updated until the task is finished.
12501252 """
12511253
1252- async for res in self ._watch (
1253- task_id = task_id ,
1254- extra_headers = extra_headers ,
1255- extra_query = extra_query ,
1256- extra_body = extra_body ,
1257- timeout = timeout ,
1258- ):
1259- if structured_output_json is not None and isinstance (structured_output_json , type ):
1260- if res .done_output is None :
1261- yield TaskViewWithOutput [BaseModel ](
1262- ** res .model_dump (),
1263- parsed_output = None ,
1264- )
1254+ async def _gen () -> AsyncIterator [TaskView | TaskViewWithOutput [T ]]:
1255+ async for res in self ._watch (
1256+ task_id = task_id ,
1257+ extra_headers = extra_headers ,
1258+ extra_query = extra_query ,
1259+ extra_body = extra_body ,
1260+ timeout = timeout ,
1261+ ):
1262+ # If a schema (type[T]) is passed, wrap with parsed_output[T]
1263+ if structured_output_json is not None and isinstance (structured_output_json , type ):
1264+ if res .done_output is None :
1265+ yield TaskViewWithOutput [T ](
1266+ ** res .model_dump (),
1267+ parsed_output = None ,
1268+ )
1269+ else :
1270+ schema : type [T ] = structured_output_json
1271+ # pydantic returns the model instance, but the type checker can’t infer it.
1272+ parsed_output : T = schema .model_validate_json (res .done_output )
1273+ yield TaskViewWithOutput [T ](
1274+ ** res .model_dump (),
1275+ parsed_output = parsed_output ,
1276+ )
12651277 else :
1266- parsed_output = structured_output_json . model_validate_json ( res . done_output )
1278+ yield res
12671279
1268- yield TaskViewWithOutput [BaseModel ](
1269- ** res .model_dump (),
1270- parsed_output = parsed_output ,
1271- )
1272-
1273- else :
1274- yield res
1280+ return _gen ()
12751281
12761282 async def _watch (
12771283 self ,
12781284 task_id : str ,
12791285 interval : float = 1 ,
12801286 * ,
1281- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
1282- # The extra values given here take precedence over values defined on the client or passed to this method.
12831287 extra_headers : Headers | None = None ,
12841288 extra_query : Query | None = None ,
12851289 extra_body : Body | None = None ,
12861290 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
1287- ) -> AsyncGenerator [ Union [ TaskView , TaskViewWithOutput [ BaseModel ]], None ]:
1291+ ) -> AsyncIterator [ TaskView ]:
12881292 """Converts a polling loop into a generator loop."""
1289- hash : str | None = None
1293+ prev_hash : str | None = None
12901294
12911295 while True :
12921296 res = await self .retrieve (
@@ -1298,9 +1302,8 @@ async def _watch(
12981302 )
12991303
13001304 res_hash = hash_task_view (res )
1301-
1302- if hash is None or res_hash != hash :
1303- hash = res_hash
1305+ if prev_hash is None or res_hash != prev_hash :
1306+ prev_hash = res_hash
13041307 yield res
13051308
13061309 if res .status == "finished" :
0 commit comments