Skip to content

Commit 2dea4ac

Browse files
committed
add implementation
1 parent bef677e commit 2dea4ac

File tree

2 files changed

+97
-56
lines changed

2 files changed

+97
-56
lines changed

src/browser_use/wrapper/parse.py

Lines changed: 94 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import asyncio
12
import hashlib
23
import json
4+
import time
35
import typing
46
from datetime import datetime
57
from typing import Any, AsyncIterator, Generic, Iterator, Type, TypeVar, Union
@@ -33,47 +35,52 @@ def default(self, o: Any) -> Any: # type: ignore[override]
3335
return super().default(o)
3436

3537

36-
def hash_task_view(task_view: TaskView) -> str:
38+
def _hash_task_view(task_view: TaskView) -> str:
3739
"""Hashes the task view to detect changes."""
3840
return hashlib.sha256(
3941
json.dumps(task_view.model_dump(), sort_keys=True, cls=CustomJSONEncoder).encode()
4042
).hexdigest()
4143

4244

43-
# def _watch(
44-
# self,
45-
# task_id: str,
46-
# interval: float = 1, request_options:typing.Optional[RequestOptions] = None,
47-
# *,
48-
# # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
49-
# # The extra values given here take precedence over values defined on the client or passed to this method.
50-
# request_options: typing.Optional[RequestOptions] = None,
51-
# ) -> Iterator[TaskView]:
52-
# """Converts a polling loop into a generator loop."""
53-
# hash: str | None = None
45+
def _parse_task_view_with_output(task_view: TaskView, schema: Type[T]) -> TaskViewWithOutput[T]:
46+
"""Parses the task view with output."""
47+
if task_view.output is None:
48+
return TaskViewWithOutput[T](**task_view.model_dump(), parsed_output=None)
5449

55-
# while True:
56-
# res = self.retrieve(
57-
# task_id=task_id,
58-
# extra_headers=extra_headers,
59-
# extra_query=extra_query,
60-
# extra_body=extra_body,
61-
# timeout=timeout,
62-
# )
50+
return TaskViewWithOutput[T](**task_view.model_dump(), parsed_output=schema.model_validate_json(task_view.output))
6351

64-
# res_hash = hash_task_view(res)
6552

66-
# if hash is None or res_hash != hash:
67-
# hash = res_hash
68-
# yield res
53+
# Sync -----------------------------------------------------------------------
6954

70-
# if res.status == "finished":
71-
# break
7255

73-
# time.sleep(interval)
56+
def _watch(
57+
client: TasksClient, task_id: str, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
58+
) -> Iterator[TaskViewWithOutput[T]]:
59+
"""Yields the latest task state on every change."""
60+
hash: str | None = None
61+
while True:
62+
res = client.get_task(task_id, request_options=request_options)
63+
res_hash = _hash_task_view(res)
7464

65+
if hash is None or res_hash != hash:
66+
hash = res_hash
67+
yield res
68+
69+
if res.status == "finished" or res.status == "stopped" or res.status == "paused":
70+
break
71+
72+
time.sleep(interval)
7573

76-
# Sync -----------------------------------------------------------------------
74+
75+
def _stream(
76+
client: TasksClient, task_id: str, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
77+
) -> Iterator[TaskStepView]:
78+
"""Streams the steps of the task and closes when the task is finished."""
79+
total_steps = 0
80+
for state in _watch(client, task_id, interval, request_options):
81+
for i in range(total_steps, len(state.steps)):
82+
total_steps = i + 1
83+
yield state.steps[i]
7784

7885

7986
class WrappedTaskCreatedResponse(TaskCreatedResponse):
@@ -87,21 +94,23 @@ def complete(
8794
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
8895
) -> TaskViewWithOutput[T]:
8996
"""Waits for the task to finish and return the result."""
90-
pass
97+
for state in _watch(self._client, self.id, interval, request_options):
98+
if state.status == "finished" or state.status == "stopped" or state.status == "paused":
99+
return state
100+
101+
raise Exception("Iterator ended without finding a finished state!")
91102

92103
def stream(
93104
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
94105
) -> Iterator[TaskStepView]:
95106
"""Streams the steps of the task and closes when the task is finished."""
96-
for i in range(10):
97-
yield TaskStepView(number=i, status="finished")
107+
return _stream(self._client, self.id, interval, request_options)
98108

99109
def watch(
100110
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
101111
) -> Iterator[TaskViewWithOutput[T]]:
102112
"""Yields the latest task state on every change."""
103-
for i in range(10):
104-
yield TaskViewWithOutput[T](status="finished")
113+
return _watch(self._client, self.id, interval, request_options)
105114

106115

107116
# Structured
@@ -120,26 +129,58 @@ def complete(
120129
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
121130
) -> TaskViewWithOutput[T]:
122131
"""Waits for the task to finish and return the result."""
123-
pass
132+
for state in _watch(self._client, self.id, interval, request_options):
133+
if state.status == "finished" or state.status == "stopped" or state.status == "paused":
134+
return _parse_task_view_with_output(state, self._schema)
135+
136+
raise Exception("Iterator ended without finding a finished state!")
124137

125138
def stream(
126139
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
127140
) -> Iterator[TaskStepView]:
128141
"""Streams the steps of the task and closes when the task is finished."""
129-
for i in range(10):
130-
yield TaskStepView(number=i, status="finished")
142+
return _stream(self._client, self.id, interval, request_options)
131143

132144
def watch(
133145
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
134146
) -> Iterator[TaskViewWithOutput[T]]:
135147
"""Yields the latest task state on every change."""
136-
for i in range(10):
137-
yield TaskViewWithOutput[T](status="finished")
148+
for state in _watch(self._client, self.id, interval, request_options):
149+
yield _parse_task_view_with_output(state, self._schema)
138150

139151

140152
# Async ----------------------------------------------------------------------
141153

142154

155+
async def _async_watch(
156+
client: AsyncTasksClient, task_id: str, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
157+
) -> AsyncIterator[TaskViewWithOutput[T]]:
158+
"""Yields the latest task state on every change."""
159+
hash: str | None = None
160+
while True:
161+
res = await client.get_task(task_id, request_options=request_options)
162+
res_hash = _hash_task_view(res)
163+
if hash is None or res_hash != hash:
164+
hash = res_hash
165+
yield res
166+
167+
if res.status == "finished" or res.status == "stopped" or res.status == "paused":
168+
break
169+
170+
await asyncio.sleep(interval)
171+
172+
173+
async def _async_stream(
174+
client: AsyncTasksClient, task_id: str, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
175+
) -> AsyncIterator[TaskStepView]:
176+
"""Streams the steps of the task and closes when the task is finished."""
177+
total_steps = 0
178+
for state in _async_watch(client, task_id, interval, request_options):
179+
for i in range(total_steps, len(state.steps)):
180+
total_steps = i + 1
181+
yield state.steps[i]
182+
183+
143184
class AsyncWrappedTaskCreatedResponse(TaskCreatedResponse):
144185
"""TaskCreatedResponse with utility methods for easier interfacing with Browser Use Cloud."""
145186

@@ -149,21 +190,23 @@ def __init__(self, id: str, client: AsyncTasksClient):
149190

150191
async def complete(self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None) -> TaskView:
151192
"""Waits for the task to finish and return the result."""
152-
pass
193+
for state in _async_watch(self._client, self.id, interval, request_options):
194+
if state.status == "finished" or state.status == "stopped" or state.status == "paused":
195+
return state
196+
197+
raise Exception("Iterator ended without finding a finished state!")
153198

154199
async def stream(
155200
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
156201
) -> AsyncIterator[TaskStepView]:
157202
"""Streams the steps of the task and closes when the task is finished."""
158-
for i in range(10):
159-
yield TaskStepView(number=i, status="finished")
203+
return _async_stream(self._client, self.id, interval, request_options)
160204

161205
async def watch(
162206
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
163207
) -> AsyncIterator[TaskView]:
164208
"""Yields the latest task state on every change."""
165-
for i in range(10):
166-
yield TaskView(status="finished")
209+
return _async_watch(self._client, self.id, interval, request_options)
167210

168211

169212
# Structured
@@ -182,18 +225,21 @@ async def complete(
182225
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
183226
) -> TaskViewWithOutput[T]:
184227
"""Waits for the task to finish and return the result."""
185-
pass
228+
for state in _async_watch(self._client, self.id, interval, request_options):
229+
if state.status == "finished" or state.status == "stopped" or state.status == "paused":
230+
return _parse_task_view_with_output(state, self._schema)
231+
232+
raise Exception("Iterator ended without finding a finished state!")
186233

187234
async def stream(
188235
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
189236
) -> AsyncIterator[TaskStepView]:
190237
"""Streams the steps of the task and closes when the task is finished."""
191-
for i in range(10):
192-
yield TaskStepView(number=i, status="finished")
238+
return _async_stream(self._client, self.id, interval, request_options)
193239

194240
async def watch(
195241
self, interval: float = 1, request_options: typing.Optional[RequestOptions] = None
196242
) -> AsyncIterator[TaskViewWithOutput[T]]:
197243
"""Yields the latest task state on every change."""
198-
for i in range(10):
199-
yield TaskViewWithOutput[T](status="finished")
244+
for state in _async_watch(self._client, self.id, interval, request_options):
245+
yield _parse_task_view_with_output(state, self._schema)

src/browser_use/wrapper/tasks/client.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TaskViewWithOutput,
1313
WrappedStructuredTaskCreatedResponse,
1414
WrappedTaskCreatedResponse,
15+
_parse_task_view_with_output,
1516
)
1617

1718

@@ -145,10 +146,7 @@ def get_task(
145146
res = super().get_task(task_id, request_options=request_options)
146147

147148
if schema is not None:
148-
if res.output is None:
149-
return TaskViewWithOutput[T](**res.model_dump(), parsed_output=None)
150-
151-
return TaskViewWithOutput[T](**res.model_dump(), parsed_output=schema.model_validate_json(res.output))
149+
return _parse_task_view_with_output(res, schema)
152150
else:
153151
return res
154152

@@ -281,9 +279,6 @@ async def get_task(
281279
res = await super().get_task(task_id, request_options=request_options)
282280

283281
if schema is not None:
284-
if res.output is None:
285-
return TaskViewWithOutput[T](**res.model_dump(), parsed_output=None)
286-
287-
return TaskViewWithOutput[T](**res.model_dump(), parsed_output=schema.model_validate_json(res.output))
282+
return _parse_task_view_with_output(res, schema)
288283
else:
289284
return res

0 commit comments

Comments
 (0)