1+ import json
12import logging
23import time
34from dataclasses import dataclass
1011from agentlab .actions import ToolsActionSet
1112from agentlab .backends .browser .base import BrowserBackend
1213from agentlab .benchmarks .abstract_env import AbstractEnv , AbstractEnvArgs
13- from agentlab .benchmarks .miniwob . task import AbstractWebTask
14+ from agentlab .benchmarks .web_task import AbstractWebTask
1415
1516logger = logging .getLogger (__name__ )
1617
18+
1719class GoalObservation (Observation ):
1820 kind : Literal ["goal_observation" ] = "goal_observation"
1921 goal : str
2022
23+
2124class PageObservation (Observation ):
2225 kind : Literal ["page_observation" ] = "page_observation"
2326 content : str
2427
2528
2629class BrowserEnv (AbstractEnv ):
27- def __init__ (self , task_name : str , task : AbstractWebTask , backend : BrowserBackend , seed : int = 0 ):
30+ def __init__ (
31+ self , task_name : str , task : AbstractWebTask , backend : BrowserBackend , seed : int = 0
32+ ):
2833 self .task_name = task_name
2934 self .task = task
3035 self .seed = seed
3136 self ._turns = 0
3237 self .max_turns = task .max_turns
3338 self .backend = backend
3439 self .backend .initialize ()
40+ self .goal = ""
3541
3642 def reset (self , seed : int ):
3743 self .seed = seed
3844 logger .info (f"Open task URL: { self .task .url } " )
39- page_content = self .backend .goto (self .task .url )
45+ self .backend .goto (self .task .url )
4046 setup_js = self .task .get_setup_js ()
4147 if setup_js :
42- js_result_str = self .backend .run_js (setup_js )
43- logger .info (f"Task reset result: { js_result_str } " )
44- return [GoalObservation (goal = js_result_str ), PageObservation (content = page_content )], {}
48+ js_out = self .backend .run_js (setup_js )
49+ out_dict = json .loads (js_out )
50+ logger .info (f"Task setup result: { out_dict } " )
51+ goal = out_dict ["goal" ]
52+ done = out_dict ["done" ]
53+ task_start_time = out_dict ["task_start_time" ]
54+ logger .info (f"Task start time: { task_start_time } " )
55+ if done :
56+ raise ValueError ("Task is already done" )
57+ self .goal = goal
58+ logger .info (f"Task goal: { self .goal } " )
59+ page_content = self .backend .page_snapshot ()
60+ logger .info (f"Initial obs: { page_content } " )
61+ return {
62+ "goal_object" : [{"type" : "text" , "text" : self .goal }],
63+ "pruned_html" : page_content ,
64+ "axtree_txt" : "" ,
65+ "last_action_error" : "" ,
66+ "focused_element_bid" : "none" ,
67+ }, {}
4568
4669 def step (self , action : ToolCallAction | str ) -> tuple [Observation , float , bool , bool , dict ]:
4770 if isinstance (action , str ):
@@ -51,49 +74,67 @@ def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool,
5174 action_exec_start = time .time ()
5275 finished = isinstance (action , StopStep )
5376 if finished :
54- observation = Observation () # empty observation
77+ observation = {
78+ "goal_object" : [{"type" : "text" , "text" : self .goal }],
79+ "pruned_html" : "Task finished" ,
80+ "axtree_txt" : "" ,
81+ "last_action_error" : "" ,
82+ "focused_element_bid" : "none" ,
83+ }
5584 else :
5685 observation = self ._step (action )
5786 action_exec_stop = time .time ()
5887 self ._turns += 1
88+ logger .info (f"Obs:\n { observation ['pruned_html' ]} " )
5989
6090 truncated = self ._turns >= self .max_turns
6191
6292 if self .task .validate_per_step or finished or truncated :
63- reward = self .calculate_reward (action , observation )
93+ reward , other = self .calculate_reward (action , observation )
94+ if other .get ("done" , False ):
95+ finished = True
6496 else :
6597 reward = 0.0
98+ other = {}
6699
67100 env_info = {
68- "step_metadata" : observation .metadata ,
69101 "action_exec_start" : action_exec_start ,
70102 "action_exec_stop" : action_exec_stop ,
71103 "action_exec_timeout" : 0.0 ,
72- }
104+ } | other
73105 obs_view = observation .short_view () if isinstance (observation , Observation ) else observation
74106 logger .info (f"Action result in observation: { obs_view } " )
75107 return observation , reward , finished , truncated , env_info
76108
77- def _step (self , action : ToolCallAction ) -> PageObservation :
109+ def _step (self , action : ToolCallAction ) -> dict :
78110 tool_result = self .backend .step (action )
79- return PageObservation (content = tool_result )
111+ return {
112+ "goal_object" : [{"type" : "text" , "text" : self .goal }],
113+ "pruned_html" : tool_result ,
114+ "axtree_txt" : "" ,
115+ "last_action_error" : "" ,
116+ "focused_element_bid" : "none" ,
117+ }
80118
81- def calculate_reward (self , action : Action , observation : PageObservation ) -> float :
119+ def calculate_reward (self , action : Action , observation : PageObservation ) -> tuple [ float , dict ] :
82120 validate_js = self .task .get_step_validate_js ()
83121 validate_result = self .backend .run_js (validate_js )
84122 reward , other = self .task .parse_validation_result (validate_result )
85- return reward
123+ return reward , other
86124
87125 def close (self ):
88126 teardown_js = self .task .get_teardown_js ()
89127 if teardown_js :
90128 js_result_str = self .backend .run_js (teardown_js )
91129 logger .info (f"Task teardown result: { js_result_str } " )
130+ self .backend .close ()
92131
93132 def actions (self ) -> list [ToolSpec ]:
94133 all_actions = self .backend .actions ()
95134 filtered_actions = self .task .filter_actions (all_actions )
96- logger .info (f"Filtered { len (filtered_actions )} actions out of { len (all_actions )} for task { self .task .dataset } " )
135+ logger .info (
136+ f"Filtered { len (filtered_actions )} actions out of { len (all_actions )} for task { self .task .dataset } "
137+ )
97138 return filtered_actions
98139
99140
@@ -104,13 +145,16 @@ class BrowserEnvArgs(AbstractEnvArgs):
104145 task_name : str
105146 backend : BrowserBackend
106147
107- def __init__ (self , task_name : str , task : AbstractWebTask , backend : BrowserBackend , task_seed : int = 0 ):
148+ def __init__ (
149+ self , task_name : str , task : AbstractWebTask , backend : BrowserBackend , task_seed : int = 0
150+ ):
108151 self .task_name = task_name
109152 self .task = task
110153 self .task_seed = task_seed
111154 self .backend = backend
112155
113156 def make_env (self , exp_dir : Path ) -> BrowserEnv :
114- env = BrowserEnv (task_name = self .task_name , task = self .task , backend = self .backend , seed = self .task_seed )
157+ env = BrowserEnv (
158+ task_name = self .task_name , task = self .task , backend = self .backend , seed = self .task_seed
159+ )
115160 return env
116-
0 commit comments