1- import json
21import logging
32import time
43from dataclasses import dataclass
54from pathlib import Path
6- from typing import Literal
7-
8- from tapeagents .core import Action , Observation , StopStep
9- from tapeagents .tool_calling import ToolCallAction , ToolSpec
105
116from agentlab .actions import ToolsActionSet
12- from agentlab .backends .browser .base import BrowserBackend
7+ from agentlab .backends .browser .base import BrowserBackend , ToolCallAction , ToolSpec
138from agentlab .benchmarks .abstract_env import AbstractEnv , AbstractEnvArgs
149from agentlab .benchmarks .web_task import AbstractWebTask
1510
1611logger = logging .getLogger (__name__ )
1712
18-
19- class GoalObservation (Observation ):
20- kind : Literal ["goal_observation" ] = "goal_observation"
21- goal : str
22-
23-
24- class PageObservation (Observation ):
25- kind : Literal ["page_observation" ] = "page_observation"
26- content : str
27-
28-
2913class BrowserEnv (AbstractEnv ):
3014 def __init__ (
3115 self , task_name : str , task : AbstractWebTask , backend : BrowserBackend , seed : int = 0
@@ -50,22 +34,23 @@ def reset(self, seed: int):
5034 page_content = self .backend .page_snapshot ()
5135 screenshot = self .backend .page_screenshot ()
5236 logger .info (f"Initial obs: { page_content } \n { screenshot } " )
53- return {
37+ obs = {
5438 "goal_object" : [{"type" : "text" , "text" : self .goal }],
5539 "pruned_html" : page_content ,
5640 "axtree_txt" : page_content ,
5741 "screenshot" : screenshot ,
5842 "last_action_error" : "" ,
5943 "focused_element_bid" : "none" ,
60- }, {}
44+ }
45+ return self .task .obs_postprocess (obs ), {}
6146
62- def step (self , action : ToolCallAction | str ) -> tuple [Observation , float , bool , bool , dict ]:
47+ def step (self , action : ToolCallAction | str ) -> tuple [dict , float , bool , bool , dict ]:
6348 if isinstance (action , str ):
6449 action = ToolsActionSet .parse_action (action )
6550 logger .info (f"BrowserEnv.step() called with action { action } " )
6651
6752 action_exec_start = time .time ()
68- finished = isinstance ( action , StopStep )
53+ finished = action . function . name == "final_step"
6954 if finished :
7055 observation = {
7156 "goal_object" : [{"type" : "text" , "text" : self .goal }],
@@ -76,6 +61,7 @@ def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool,
7661 }
7762 else :
7863 observation = self ._step (action )
64+ observation = self .task .obs_postprocess (observation )
7965 action_exec_stop = time .time ()
8066 self ._turns += 1
8167 logger .info (f"Obs:\n { observation ['pruned_html' ]} " )
@@ -95,8 +81,7 @@ def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool,
9581 "action_exec_stop" : action_exec_stop ,
9682 "action_exec_timeout" : 0.0 ,
9783 } | other
98- obs_view = observation .short_view () if isinstance (observation , Observation ) else observation
99- logger .info (f"Action result in observation: { obs_view } " )
84+ logger .info (f"Action result in observation: { observation } " )
10085 return observation , reward , finished , truncated , env_info
10186
10287 def _step (self , action : ToolCallAction ) -> dict :
@@ -108,7 +93,7 @@ def _step(self, action: ToolCallAction) -> dict:
10893 "focused_element_bid" : "none" ,
10994 }
11095
111- def validate_task (self , action : Action , observation : PageObservation ) -> tuple [float , dict ]:
96+ def validate_task (self , action : ToolCallAction , observation : dict ) -> tuple [float , dict ]:
11297 validate_js = self .task .get_step_validate_js ()
11398 validate_result = self .backend .run_js (validate_js )
11499 reward , other = self .task .parse_validation_result (validate_result )
0 commit comments