|
| 1 | +import logging |
| 2 | +import time |
| 3 | +from typing import Any, Literal |
| 4 | + |
| 5 | +from tapeagents.core import Action, Observation, StopStep |
| 6 | + |
| 7 | +from agentlab.backends.browser.base import BrowserBackend |
| 8 | +from agentlab.benchmarks.abstract_env import AbstractEnv |
| 9 | +from agentlab.benchmarks.miniwob.task import AbstractWebTask |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
| 13 | + |
class PageObservation(Observation):
    """Observation wrapping the result of a browser tool call.

    Instances are built by ``BrowserEnv._step`` from whatever the backend's
    ``call_tool`` returns.
    """

    # Fixed tag identifying this observation type.
    kind: Literal["page_observation"] = "page_observation"
    # Result of the tool call, as returned by the backend.
    content: str
| 17 | + |
class BrowserAction(Action):
    """Action asking the browser backend to invoke a named tool.

    ``BrowserEnv._step`` forwards ``name`` and ``arguments`` to the backend's
    ``call_tool``.
    """

    # Fixed tag identifying this action type.
    kind: Literal["browser_action"] = "browser_action"
    # Name of the backend tool to call.
    name: str
    # Arguments passed to the tool alongside its name.
    arguments: dict[str, Any]
| 22 | + |
| 23 | + |
class BrowserEnv(AbstractEnv):
    """Environment that runs a web task against a browser backend.

    Actions are forwarded to the backend as tool calls; task setup,
    validation and teardown are performed by running task-provided
    JavaScript through the backend.
    """

    def __init__(
        self,
        task_name: str,
        task: AbstractWebTask,
        backend: BrowserBackend,
        seed: int = 0,
        max_turns: "int | None" = None,
    ):
        """Store the task, backend and episode settings.

        Args:
            task_name: Name of the task being run.
            task: Task object supplying setup/validation/teardown JavaScript.
            backend: Browser backend used to run JavaScript and tool calls.
            seed: Seed stored on the environment.
            max_turns: Optional limit on the number of steps before the
                episode is reported as truncated. When omitted, any
                ``max_turns`` attribute already provided by the class is
                used; if there is none, the episode is never truncated.
        """
        self.task_name = task_name
        self.task = task
        self.seed = seed
        self.backend = backend
        self._turns = 0
        # Only override when explicitly given, so a value supplied by the
        # base class (if any) is preserved.
        if max_turns is not None:
            self.max_turns = max_turns

    def reset(self, seed: int):
        """Start a new episode: store the seed, clear the turn counter and run the task setup JS."""
        self.seed = seed
        # A reused environment must not inherit the previous episode's turns,
        # otherwise truncation would trigger too early.
        self._turns = 0
        setup_js = self.task.get_setup_js()
        if setup_js:
            js_result_str = self.backend.run_js(setup_js)
            logger.info("Task reset result: %s", js_result_str)

    def step(self, action: BrowserAction) -> "tuple[Observation, float | None, bool, bool, dict]":
        """Execute one action and report the outcome.

        Args:
            action: The action to execute. A ``StopStep`` ends the episode
                without calling the backend.

        Returns:
            A tuple ``(observation, reward, finished, truncated, env_info)``.
            ``reward`` is ``None`` when no validation ran for this step, i.e.
            the task does not validate per step and the episode is neither
            finished nor truncated.
        """
        logger.info("BrowserEnv.step() called with action %s", type(action))

        action_exec_start = time.time()
        finished = isinstance(action, StopStep)
        if finished:
            observation = Observation()  # empty observation
        else:
            observation = self._step(action)
        action_exec_stop = time.time()
        self._turns += 1

        # ``max_turns`` may be absent when neither the constructor nor the
        # base class supplied it; treat that as "no turn limit".
        max_turns = getattr(self, "max_turns", None)
        truncated = max_turns is not None and self._turns >= max_turns

        if self.task.validate_per_step or finished or truncated:
            reward = self.calculate_reward(action, observation)
        else:
            reward = None

        env_info = {
            "step_metadata": observation.metadata,
            "action_exec_start": action_exec_start,
            "action_exec_stop": action_exec_stop,
            "action_exec_timeout": 0.0,
        }
        obs_view = observation.short_view() if isinstance(observation, Observation) else observation
        logger.info("Action result in observation: %s", obs_view)
        return observation, reward, finished, truncated, env_info

    def _step(self, action: BrowserAction) -> PageObservation:
        """Call the backend tool named by the action and wrap its result."""
        tool_result = self.backend.call_tool(action.name, action.arguments)
        return PageObservation(content=tool_result)

    def calculate_reward(self, action: Action, observation: PageObservation) -> float:
        """Run the task's step-validation JS and return the parsed reward.

        ``action`` and ``observation`` are accepted but not used here.
        """
        validate_js = self.task.get_step_validate_js()
        validate_result = self.backend.run_js(validate_js)
        # The second element of the parsed result is not needed here.
        reward, _ = self.task.parse_validation_result(validate_result)
        return reward

    def close(self):
        """Run the task teardown JS, if the task provides any, and log its result."""
        teardown_js = self.task.get_teardown_js()
        if teardown_js:
            js_result_str = self.backend.run_js(teardown_js)
            logger.info("Task teardown result: %s", js_result_str)
0 commit comments