Skip to content

Commit cf68ef6

Browse files
committed
workarena bench, reuse bgym task inside
1 parent 462038e commit cf68ef6

File tree

5 files changed

+130
-24
lines changed

5 files changed

+130
-24
lines changed

src/agentlab/backends/browser/env.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from dataclasses import dataclass
44
from pathlib import Path
55

6-
from browsergym.core.task import AbstractBrowserTask
7-
86
from agentlab.actions import ToolCall, ToolsActionSet, ToolSpec
97
from agentlab.backends.browser.base import BrowserBackend
108
from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
@@ -27,7 +25,7 @@ def final_step():
2725

2826
class BrowserEnv(AbstractEnv):
2927
def __init__(
30-
self, task_name: str, task: AbstractWebTask | AbstractBrowserTask, backend: BrowserBackend, seed: int = 0
28+
self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0
3129
):
3230
self.task_name = task_name
3331
self.task = task
@@ -36,20 +34,12 @@ def __init__(
3634
self.backend = backend
3735
self.backend.initialize()
3836
self.goal = ""
39-
if isinstance(self.task, AbstractBrowserTask) and not self.backend.has_pw_page:
40-
raise ValueError(
41-
"Legacy task requires a backend with direct playwright page access."
42-
)
4337

4438
def reset(self, seed: int):
4539
self.seed = seed
46-
if isinstance(self.task, AbstractBrowserTask):
47-
self.goal, task_info = self.task.setup(page=self.backend.page)
48-
obs = self._get_obs()
49-
else:
50-
self.goal, task_info = self.task.setup(backend=self.backend)
51-
obs = self._get_obs()
52-
obs = self.task.obs_postprocess(obs)
40+
self.goal, task_info = self.task.setup(backend=self.backend)
41+
obs = self._get_obs()
42+
obs = self.task.obs_postprocess(obs)
5343
return obs, task_info
5444

5545
def _get_obs(self) -> dict:
@@ -86,21 +76,15 @@ def step(self, action: ToolCall | str) -> tuple[dict, float, bool, bool, dict]:
8676

8777
observation = self.obs_postprocess(observation)
8878

89-
if isinstance(self.task, AbstractBrowserTask):
90-
reward, done, _, info = self.task.validate(page=self.backend.page, chat_messages=[])
91-
elif self.task.validate_per_step or done or truncated:
92-
reward, info = self.task.validate()
93-
if info.get("done", False):
94-
done = True
95-
else:
96-
reward = 0.0
97-
info = {}
79+
reward, info = self.task.validate()
80+
if info.get("done", False):
81+
done = True
9882

9983
env_info = {
10084
**info,
10185
"action_exec_start": action_exec_start,
10286
"action_exec_stop": action_exec_stop,
103-
"action_exec_timeout": 0.0
87+
"action_exec_timeout": 0.0,
10488
}
10589
logger.info(f"Action result in observation: {observation}")
10690
return observation, reward, done, truncated, env_info

src/agentlab/backends/browser/playwright.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import time
23
from io import BytesIO
34
from typing import Any, Callable
45

@@ -82,6 +83,10 @@ def browser_mouse_click_xy(self, x: int, y: int):
8283
"""Click at a given x, y coordinate using the mouse."""
8384
self._page.mouse.click(x, y, delay=100)
8485

86+
def browser_wait(self, seconds: int):
87+
"""Wait for a given number of seconds, up to 10 seconds."""
88+
time.sleep(min(seconds, 10))
89+
8590
def evaluate_js(self, js: str):
8691
js_result = self._page.evaluate(js)
8792
logger.info(f"JS result: {js_result}")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .benchmark import WorkArenaBenchmark
2+
from .task import WorkarenaTask
3+
4+
__all__ = ["WorkArenaBenchmark", "WorkarenaTask"]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import logging
2+
from typing import Any
3+
4+
from browsergym.workarena import get_all_tasks_agents
5+
from browsergym.workarena.instance import SNowInstance
6+
from pydantic import ConfigDict
7+
from ray.cloudpickle import instance
8+
9+
from agentlab.actions import ToolsActionSet
10+
from agentlab.backends.browser.base import BrowserBackend
11+
from agentlab.backends.browser.env import BrowserEnvArgs
12+
from agentlab.benchmarks.abstract_env import AbstractBenchmark
13+
14+
from .task import WorkarenaTask
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class WorkArenaBenchmark(AbstractBenchmark):
20+
model_config = ConfigDict(arbitrary_types_allowed=True)
21+
22+
backend_cls: type[BrowserBackend]
23+
name: str = "workarena"
24+
level: str = "l1"
25+
env_args_list: list[BrowserEnvArgs] = None # type: ignore
26+
dataset: list[WorkarenaTask] = None # type: ignore
27+
is_multi_tab: bool = False
28+
high_level_action_set_args: ToolsActionSet = None # type: ignore
29+
_snow_instance: SNowInstance = None # type: ignore
30+
31+
def model_post_init(self, __context: Any) -> None:
32+
self.name = f"workarena_{self.level}_{self.backend_cls.__name__.lower()}"
33+
self._snow_instance = SNowInstance()
34+
self.env_args_list = []
35+
if self.dataset is None:
36+
task_seed_tuples = get_all_tasks_agents(filter=self.level)
37+
self.dataset = self.load_tasks(task_seed_tuples, self.level)
38+
for task in self.dataset:
39+
env_args = BrowserEnvArgs(task=task, backend_cls=self.backend_cls)
40+
self.env_args_list.append(env_args)
41+
logger.info(f"Loaded {len(self.env_args_list)} workarena tasks")
42+
43+
def load_tasks(self, task_seed_tuples: list[tuple[type, int]], level: str) -> list[WorkarenaTask]:
44+
tasks = []
45+
46+
for task_cls, seed in task_seed_tuples:
47+
task = WorkarenaTask(
48+
url="",
49+
task_id=task_cls.get_task_id(),
50+
instance=self._snow_instance,
51+
task_cls=task_cls,
52+
level=level,
53+
seed=seed,
54+
)
55+
tasks.append(task)
56+
return tasks
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import logging
2+
from typing import ClassVar
3+
4+
from browsergym.utils.obs import prune_html
5+
from browsergym.workarena.instance import SNowInstance
6+
from browsergym.workarena.tasks.base import AbstractServiceNowTask
7+
from pydantic import ConfigDict
8+
9+
from agentlab.backends.browser import BrowserBackend
10+
from agentlab.benchmarks.web_task import AbstractWebTask
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class WorkarenaTask(AbstractWebTask):
16+
model_config = ConfigDict(arbitrary_types_allowed=True)
17+
18+
dataset: str = "workarena"
19+
level: str
20+
task_cls: type[AbstractServiceNowTask]
21+
seed: int
22+
instance: SNowInstance
23+
_task_obj: AbstractServiceNowTask = None # type: ignore
24+
actions_whitelist: ClassVar[list[str]] = [
25+
"browser_press_key",
26+
"browser_type",
27+
"browser_click",
28+
"browser_drag",
29+
"browser_hover",
30+
"browser_select_option",
31+
"browser_mouse_click_xy",
32+
"browser_wait",
33+
]
34+
35+
def setup(self, backend: BrowserBackend) -> tuple[str, dict]:
36+
if not backend.has_pw_page:
37+
raise ValueError("Workarena task requires a backend with playwright page access.")
38+
self._backend = backend
39+
self._task_obj = self.task_cls(instance=self.instance, seed=self.seed) # type: ignore
40+
self.url = self._task_obj.start_url
41+
goal, info = self._task_obj.setup(backend.page)
42+
logger.info(f"Current backend page URL: {backend.page.url}")
43+
# backend.goto(self.url)
44+
return goal, info
45+
46+
def teardown(self) -> None:
47+
self._task_obj.teardown()
48+
49+
def validate(self) -> tuple[float, dict]:
50+
reward, done, _, info = self._task_obj.validate(page=self._backend.page, chat_messages=[])
51+
info["done"] = done
52+
return reward, info
53+
54+
def obs_postprocess(self, obs: dict) -> dict:
55+
html = obs.pop("html", "")
56+
obs["pruned_html"] = prune_html(html)
57+
return obs

0 commit comments

Comments
 (0)