@@ -417,42 +417,42 @@ def retrieve(
417
417
def stream (
418
418
self ,
419
419
task_id : str ,
420
- structured_output_json : None | NotGiven = NOT_GIVEN ,
420
+ structured_output_json : type [ T ] ,
421
421
* ,
422
422
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
423
423
# The extra values given here take precedence over values defined on the client or passed to this method.
424
424
extra_headers : Headers | None = None ,
425
425
extra_query : Query | None = None ,
426
426
extra_body : Body | None = None ,
427
427
timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
428
- ) -> Iterator [TaskView ]: ...
428
+ ) -> Iterator [TaskViewWithOutput [ T ] ]: ...
429
429
430
430
@overload
431
431
def stream (
432
432
self ,
433
433
task_id : str ,
434
- structured_output_json : type [ T ] ,
434
+ structured_output_json : None | NotGiven = NOT_GIVEN ,
435
435
* ,
436
436
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
437
437
# The extra values given here take precedence over values defined on the client or passed to this method.
438
438
extra_headers : Headers | None = None ,
439
439
extra_query : Query | None = None ,
440
440
extra_body : Body | None = None ,
441
441
timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
442
- ) -> Iterator [TaskViewWithOutput [ T ] ]: ...
442
+ ) -> Iterator [TaskView ]: ...
443
443
444
444
def stream (
445
445
self ,
446
446
task_id : str ,
447
- structured_output_json : type [BaseModel ] | None | NotGiven = NOT_GIVEN ,
447
+ structured_output_json : type [T ] | None | NotGiven = NOT_GIVEN ,
448
448
* ,
449
449
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
450
450
# The extra values given here take precedence over values defined on the client or passed to this method.
451
451
extra_headers : Headers | None = None ,
452
452
extra_query : Query | None = None ,
453
453
extra_body : Body | None = None ,
454
454
timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
455
- ) -> Iterator [Union [ TaskView , TaskViewWithOutput [BaseModel ] ]]:
455
+ ) -> Iterator [TaskView | TaskViewWithOutput [T ]]:
456
456
"""
457
457
Stream the task view as it is updated until the task is finished.
458
458
"""
@@ -466,14 +466,15 @@ def stream(
466
466
):
467
467
if structured_output_json is not None and isinstance (structured_output_json , type ):
468
468
if res .done_output is None :
469
- yield TaskViewWithOutput [BaseModel ](
469
+ yield TaskViewWithOutput [T ](
470
470
** res .model_dump (),
471
471
parsed_output = None ,
472
472
)
473
473
else :
474
- parsed_output = structured_output_json .model_validate_json (res .done_output )
474
+ schema : type [T ] = structured_output_json
475
+ parsed_output : T = schema .model_validate_json (res .done_output )
475
476
476
- yield TaskViewWithOutput [BaseModel ](
477
+ yield TaskViewWithOutput [T ](
477
478
** res .model_dump (),
478
479
parsed_output = parsed_output ,
479
480
)
@@ -1246,7 +1247,7 @@ def stream(
1246
1247
extra_query : Query | None = None ,
1247
1248
extra_body : Body | None = None ,
1248
1249
timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
1249
- ) -> AsyncIterator [TaskView ] | AsyncIterator [ TaskViewWithOutput [T ]]:
1250
+ ) -> AsyncIterator [TaskView | TaskViewWithOutput [T ]]:
1250
1251
"""
1251
1252
Stream the task view as it is updated until the task is finished.
1252
1253
"""
@@ -1259,7 +1260,6 @@ async def _gen() -> AsyncIterator[TaskView | TaskViewWithOutput[T]]:
1259
1260
extra_body = extra_body ,
1260
1261
timeout = timeout ,
1261
1262
):
1262
- # If a schema (type[T]) is passed, wrap with parsed_output[T]
1263
1263
if structured_output_json is not None and isinstance (structured_output_json , type ):
1264
1264
if res .done_output is None :
1265
1265
yield TaskViewWithOutput [T ](
0 commit comments