11import logging
22import time
3- from typing import Any , Literal
3+ from dataclasses import dataclass
4+ from pathlib import Path
5+ from typing import Literal
46
57from tapeagents .core import Action , Observation , StopStep
8+ from tapeagents .tool_calling import ToolCallAction , ToolSpec
69
710from agentlab .backends .browser .base import BrowserBackend
8- from agentlab .benchmarks .abstract_env import AbstractEnv
11+ from agentlab .benchmarks .abstract_env import AbstractEnv , AbstractEnvArgs
912from agentlab .benchmarks .miniwob .task import AbstractWebTask
1013
1114logger = logging .getLogger (__name__ )
1215
class GoalObservation(Observation):
    """Observation that carries the goal text for the current task."""

    # Discriminator value identifying this observation type.
    kind: Literal["goal_observation"] = "goal_observation"
    # Goal text; BrowserEnv.reset fills it with the result of the task's setup script.
    goal: str
1319
class PageObservation(Observation):
    """Observation that carries page content returned by the browser backend."""

    # Discriminator value identifying this observation type.
    kind: Literal["page_observation"] = "page_observation"
    # Content string produced by the backend (navigation or step result).
    content: str
1723
18- class BrowserAction (Action ):
19- kind : Literal ["browser_action" ] = "browser_action"
20- name : str
21- arguments : dict [str , Any ]
22-
2324
2425class BrowserEnv (AbstractEnv ):
def __init__(self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0):
    """Store the task and backend, then initialize the backend.

    Args:
        task_name: Name of the task this environment runs.
        task: Web task object providing the URL and setup/teardown scripts.
        backend: Browser backend used to drive the page.
        seed: Initial seed value; overwritten by each call to reset().
    """
    # Plain bookkeeping first; none of these assignments depend on each other.
    self._turns = 0
    self.seed = seed
    self.task = task
    self.task_name = task_name
    # The backend is initialized eagerly, as soon as the environment is constructed.
    self.backend = backend
    self.backend.initialize()
3133
def reset(self, seed: int) -> tuple[list[Observation], dict]:
    """Open the task page, run the task's setup script, and return the initial observations.

    Args:
        seed: Seed stored on the environment for this episode.

    Returns:
        A pair ``(observations, info)``. ``observations`` is a GoalObservation
        holding the setup script's result (empty string when the task has no
        setup script) followed by a PageObservation holding what the backend
        returned for the task URL. ``info`` is an empty dict.
    """
    self.seed = seed
    logger.info(f"Open task URL: {self.task.url}")
    page_content = self.backend.goto(self.task.url)
    # Bind a default first: without it, a task that has no setup JS would hit an
    # unbound ``js_result_str`` when the observations are built below.
    js_result_str = ""
    setup_js = self.task.get_setup_js()
    if setup_js:
        js_result_str = self.backend.run_js(setup_js)
        logger.info(f"Task reset result: {js_result_str}")
    # NOTE(review): page_content was captured before the setup script ran — confirm
    # that the pre-setup page is what the agent should see first.
    return [GoalObservation(goal=js_result_str), PageObservation(content=page_content)], {}
3843
39- def step (self , action : BrowserAction ) -> tuple [Observation , float , bool , bool , dict ]:
40- logger .info (f"BrowserEnv.step() called with action { type ( action ) } " )
44+ def step (self , action : ToolCallAction ) -> tuple [Observation , float , bool , bool , dict ]:
45+ logger .info (f"BrowserEnv.step() called with action { action . function . name } " )
4146
4247 action_exec_start = time .time ()
4348 finished = isinstance (action , StopStep )
@@ -65,8 +70,8 @@ def step(self, action: BrowserAction) -> tuple[Observation, float, bool, bool, d
6570 logger .info (f"Action result in observation: { obs_view } " )
6671 return observation , reward , finished , truncated , env_info
6772
68- def _step (self , action : Action ) -> PageObservation :
69- tool_result = self .backend .call_tool (action . name , action . arguments )
def _step(self, action: ToolCallAction) -> PageObservation:
    """Forward the action to the backend and wrap its result in a PageObservation."""
    return PageObservation(content=self.backend.step(action))
7176
7277 def calculate_reward (self , action : Action , observation : PageObservation ) -> float :
@@ -80,3 +85,28 @@ def close(self):
8085 if teardown_js :
8186 js_result_str = self .backend .run_js (teardown_js )
8287 logger .info (f"Task teardown result: { js_result_str } " )
88+
def actions(self) -> list[ToolSpec]:
    """Return the backend's actions after the task has filtered them.

    Returns:
        The list produced by ``self.task.filter_actions`` when given everything
        ``self.backend.actions()`` reports.
    """
    available = self.backend.actions()
    allowed = self.task.filter_actions(available)
    logger.info(f"Filtered {len(allowed)} actions out of {len(available)} for task {self.task.dataset}")
    return allowed
94+
95+
@dataclass
class BrowserEnvArgs(AbstractEnvArgs):
    """Arguments from which a BrowserEnv for a single task can be built."""

    task: AbstractWebTask
    task_seed: int
    task_name: str
    backend: BrowserBackend

    def __init__(
        self,
        task_name: str,
        task: AbstractWebTask,
        backend: BrowserBackend,
        task_seed: int = 0,
    ):
        """Store the arguments.

        This explicit constructor is used instead of a dataclass-generated one,
        so the positional order and the ``task_seed`` default are the ones
        written here rather than the field declaration order.
        """
        self.task = task
        self.task_seed = task_seed
        self.task_name = task_name
        self.backend = backend

    def make_env(self, exp_dir: Path) -> BrowserEnv:
        """Build a BrowserEnv from the stored arguments.

        Args:
            exp_dir: Experiment directory; not used by this implementation.

        Returns:
            A new BrowserEnv seeded with ``task_seed``.
        """
        return BrowserEnv(
            task_name=self.task_name,
            task=self.task,
            backend=self.backend,
            seed=self.task_seed,
        )
112+
0 commit comments