@@ -417,42 +417,42 @@ def retrieve(
417417 def stream (
418418 self ,
419419 task_id : str ,
420- structured_output_json : None | NotGiven = NOT_GIVEN ,
420+ structured_output_json : type [ T ] ,
421421 * ,
422422 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
423423 # The extra values given here take precedence over values defined on the client or passed to this method.
424424 extra_headers : Headers | None = None ,
425425 extra_query : Query | None = None ,
426426 extra_body : Body | None = None ,
427427 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
428- ) -> Iterator [TaskView ]: ...
428+ ) -> Iterator [TaskViewWithOutput [ T ] ]: ...
429429
430430 @overload
431431 def stream (
432432 self ,
433433 task_id : str ,
434- structured_output_json : type [ T ] ,
434+ structured_output_json : None | NotGiven = NOT_GIVEN ,
435435 * ,
436436 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
437437 # The extra values given here take precedence over values defined on the client or passed to this method.
438438 extra_headers : Headers | None = None ,
439439 extra_query : Query | None = None ,
440440 extra_body : Body | None = None ,
441441 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
442- ) -> Iterator [TaskViewWithOutput [ T ] ]: ...
442+ ) -> Iterator [TaskView ]: ...
443443
444444 def stream (
445445 self ,
446446 task_id : str ,
447- structured_output_json : type [BaseModel ] | None | NotGiven = NOT_GIVEN ,
447+ structured_output_json : type [T ] | None | NotGiven = NOT_GIVEN ,
448448 * ,
449449 # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
450450 # The extra values given here take precedence over values defined on the client or passed to this method.
451451 extra_headers : Headers | None = None ,
452452 extra_query : Query | None = None ,
453453 extra_body : Body | None = None ,
454454 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
455- ) -> Iterator [Union [ TaskView , TaskViewWithOutput [BaseModel ] ]]:
455+ ) -> Iterator [TaskView | TaskViewWithOutput [T ]]:
456456 """
457457 Stream the task view as it is updated until the task is finished.
458458 """
@@ -466,14 +466,15 @@ def stream(
466466 ):
467467 if structured_output_json is not None and isinstance (structured_output_json , type ):
468468 if res .done_output is None :
469- yield TaskViewWithOutput [BaseModel ](
469+ yield TaskViewWithOutput [T ](
470470 ** res .model_dump (),
471471 parsed_output = None ,
472472 )
473473 else :
474- parsed_output = structured_output_json .model_validate_json (res .done_output )
474+ schema : type [T ] = structured_output_json
475+ parsed_output : T = schema .model_validate_json (res .done_output )
475476
476- yield TaskViewWithOutput [BaseModel ](
477+ yield TaskViewWithOutput [T ](
477478 ** res .model_dump (),
478479 parsed_output = parsed_output ,
479480 )
@@ -1246,7 +1247,7 @@ def stream(
12461247 extra_query : Query | None = None ,
12471248 extra_body : Body | None = None ,
12481249 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
1249- ) -> AsyncIterator [TaskView ] | AsyncIterator [ TaskViewWithOutput [T ]]:
1250+ ) -> AsyncIterator [TaskView | TaskViewWithOutput [T ]]:
12501251 """
12511252 Stream the task view as it is updated until the task is finished.
12521253 """
@@ -1259,7 +1260,6 @@ async def _gen() -> AsyncIterator[TaskView | TaskViewWithOutput[T]]:
12591260 extra_body = extra_body ,
12601261 timeout = timeout ,
12611262 ):
1262- # If a schema (type[T]) is passed, wrap with parsed_output[T]
12631263 if structured_output_json is not None and isinstance (structured_output_json , type ):
12641264 if res .done_output is None :
12651265 yield TaskViewWithOutput [T ](
0 commit comments