Skip to content

Commit 64bc583

Browse files
committed
Update tasks.py
1 parent 917c3a4 commit 64bc583

File tree

1 file changed

+45
-42
lines changed

1 file changed

+45
-42
lines changed

src/browser_use_sdk/resources/tasks.py

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import time
77
import asyncio
8-
from typing import Dict, List, Union, TypeVar, Optional, Generator, AsyncGenerator, overload
8+
from typing import Dict, List, Union, TypeVar, Iterator, Optional, AsyncIterator, overload
99
from datetime import datetime
1010
from typing_extensions import Literal
1111

@@ -417,14 +417,15 @@ def retrieve(
417417
def stream(
418418
self,
419419
task_id: str,
420+
structured_output_json: None | NotGiven = NOT_GIVEN,
420421
*,
421422
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
422423
# The extra values given here take precedence over values defined on the client or passed to this method.
423424
extra_headers: Headers | None = None,
424425
extra_query: Query | None = None,
425426
extra_body: Body | None = None,
426427
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
427-
) -> Generator[TaskView, None]: ...
428+
) -> Iterator[TaskView]: ...
428429

429430
@overload
430431
def stream(
@@ -438,20 +439,20 @@ def stream(
438439
extra_query: Query | None = None,
439440
extra_body: Body | None = None,
440441
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
441-
) -> Generator[TaskViewWithOutput[T], None]: ...
442+
) -> Iterator[TaskViewWithOutput[T]]: ...
442443

443444
def stream(
444445
self,
445446
task_id: str,
446-
structured_output_json: Optional[type[BaseModel]] | NotGiven = NOT_GIVEN,
447+
structured_output_json: type[BaseModel] | None | NotGiven = NOT_GIVEN,
447448
*,
448449
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
449450
# The extra values given here take precedence over values defined on the client or passed to this method.
450451
extra_headers: Headers | None = None,
451452
extra_query: Query | None = None,
452453
extra_body: Body | None = None,
453454
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
454-
) -> Generator[Union[TaskView, TaskViewWithOutput[BaseModel]], None]:
455+
) -> Iterator[Union[TaskView, TaskViewWithOutput[BaseModel]]]:
455456
"""
456457
Stream the task view as it is updated until the task is finished.
457458
"""
@@ -491,7 +492,7 @@ def _watch(
491492
extra_query: Query | None = None,
492493
extra_body: Body | None = None,
493494
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
494-
) -> Generator[Union[TaskView, TaskViewWithOutput[BaseModel]], None]:
495+
) -> Iterator[TaskView]:
495496
"""Converts a polling loop into a generator loop."""
496497
hash: str | None = None
497498

@@ -1207,86 +1208,89 @@ async def retrieve(
12071208
)
12081209

12091210
@overload
1210-
async def stream(
1211+
def stream(
12111212
self,
12121213
task_id: str,
1214+
structured_output_json: type[T],
12131215
*,
12141216
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
12151217
# The extra values given here take precedence over values defined on the client or passed to this method.
12161218
extra_headers: Headers | None = None,
12171219
extra_query: Query | None = None,
12181220
extra_body: Body | None = None,
12191221
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
1220-
) -> AsyncGenerator[TaskView, None]: ...
1222+
) -> AsyncIterator[TaskViewWithOutput[T]]: ...
12211223

12221224
@overload
1223-
async def stream(
1225+
def stream(
12241226
self,
12251227
task_id: str,
1226-
structured_output_json: type[T],
1228+
structured_output_json: None | NotGiven = NOT_GIVEN,
12271229
*,
12281230
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
12291231
# The extra values given here take precedence over values defined on the client or passed to this method.
12301232
extra_headers: Headers | None = None,
12311233
extra_query: Query | None = None,
12321234
extra_body: Body | None = None,
12331235
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
1234-
) -> AsyncGenerator[TaskViewWithOutput[T], None]: ...
1236+
) -> AsyncIterator[TaskView]: ...
12351237

1236-
async def stream(
1238+
def stream(
12371239
self,
12381240
task_id: str,
1239-
structured_output_json: Optional[type[BaseModel]] | NotGiven = NOT_GIVEN,
1241+
structured_output_json: type[T] | None | NotGiven = NOT_GIVEN,
12401242
*,
12411243
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
12421244
# The extra values given here take precedence over values defined on the client or passed to this method.
12431245
extra_headers: Headers | None = None,
12441246
extra_query: Query | None = None,
12451247
extra_body: Body | None = None,
12461248
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
1247-
) -> AsyncGenerator[Union[TaskView, TaskViewWithOutput[BaseModel]], None]:
1249+
) -> AsyncIterator[TaskView] | AsyncIterator[TaskViewWithOutput[T]]:
12481250
"""
12491251
Stream the task view as it is updated until the task is finished.
12501252
"""
12511253

1252-
async for res in self._watch(
1253-
task_id=task_id,
1254-
extra_headers=extra_headers,
1255-
extra_query=extra_query,
1256-
extra_body=extra_body,
1257-
timeout=timeout,
1258-
):
1259-
if structured_output_json is not None and isinstance(structured_output_json, type):
1260-
if res.done_output is None:
1261-
yield TaskViewWithOutput[BaseModel](
1262-
**res.model_dump(),
1263-
parsed_output=None,
1264-
)
1254+
async def _gen() -> AsyncIterator[TaskView | TaskViewWithOutput[T]]:
1255+
async for res in self._watch(
1256+
task_id=task_id,
1257+
extra_headers=extra_headers,
1258+
extra_query=extra_query,
1259+
extra_body=extra_body,
1260+
timeout=timeout,
1261+
):
1262+
# If a schema (type[T]) is passed, wrap with parsed_output[T]
1263+
if structured_output_json is not None and isinstance(structured_output_json, type):
1264+
if res.done_output is None:
1265+
yield TaskViewWithOutput[T](
1266+
**res.model_dump(),
1267+
parsed_output=None,
1268+
)
1269+
else:
1270+
schema: type[T] = structured_output_json
1271+
# pydantic returns the model instance, but the type checker can’t infer it.
1272+
parsed_output: T = schema.model_validate_json(res.done_output)
1273+
yield TaskViewWithOutput[T](
1274+
**res.model_dump(),
1275+
parsed_output=parsed_output,
1276+
)
12651277
else:
1266-
parsed_output = structured_output_json.model_validate_json(res.done_output)
1278+
yield res
12671279

1268-
yield TaskViewWithOutput[BaseModel](
1269-
**res.model_dump(),
1270-
parsed_output=parsed_output,
1271-
)
1272-
1273-
else:
1274-
yield res
1280+
return _gen()
12751281

12761282
async def _watch(
12771283
self,
12781284
task_id: str,
12791285
interval: float = 1,
12801286
*,
1281-
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
1282-
# The extra values given here take precedence over values defined on the client or passed to this method.
12831287
extra_headers: Headers | None = None,
12841288
extra_query: Query | None = None,
12851289
extra_body: Body | None = None,
12861290
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
1287-
) -> AsyncGenerator[Union[TaskView, TaskViewWithOutput[BaseModel]], None]:
1291+
) -> AsyncIterator[TaskView]:
12881292
"""Converts a polling loop into a generator loop."""
1289-
hash: str | None = None
1293+
prev_hash: str | None = None
12901294

12911295
while True:
12921296
res = await self.retrieve(
@@ -1298,9 +1302,8 @@ async def _watch(
12981302
)
12991303

13001304
res_hash = hash_task_view(res)
1301-
1302-
if hash is None or res_hash != hash:
1303-
hash = res_hash
1305+
if prev_hash is None or res_hash != prev_hash:
1306+
prev_hash = res_hash
13041307
yield res
13051308

13061309
if res.status == "finished":

0 commit comments

Comments
 (0)