Skip to content

Commit 61a537f

Browse files
committed
Remove the tapeagents dependency and add a task-level observation postprocessing hook
1 parent 955e0d3 commit 61a537f

File tree

6 files changed

+27
-34
lines changed

6 files changed

+27
-34
lines changed

src/agentlab/actions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
22
import logging
3+
from typing import Literal
34

45
from bgym import AbstractActionSet
5-
from tapeagents.tool_calling import FunctionCall, ToolCallAction, ToolSpec
6+
from pydantic import BaseModel, Field
67

8+
from agentlab.backends.browser.base import FunctionCall, ToolCallAction, ToolSpec
79
from agentlab.llm.llm_utils import parse_html_tags_raise
810

911
logger = logging.getLogger(__name__)

src/agentlab/agents/generic_agent/generic_agent.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,10 @@
1010

1111
from copy import deepcopy
1212
from dataclasses import asdict, dataclass
13-
from functools import partial
1413
from warnings import warn
1514

16-
import bgym
1715
from bgym import Benchmark
1816
from browsergym.experiments.agent import Agent, AgentInfo
19-
from tapeagents.tool_calling import ToolSpec
2017

2118
from agentlab.actions import ToolsActionSet
2219
from agentlab.agents import dynamic_prompting as dp
@@ -67,7 +64,7 @@ def prepare(self):
6764
def close(self):
6865
return self.chat_model_args.close_server()
6966

70-
def make_agent(self, actions: list[ToolSpec] | None = None):
67+
def make_agent(self, actions: list | None = None):
7168
return GenericAgent(
7269
chat_model_args=self.chat_model_args,
7370
flags=self.flags,
@@ -83,7 +80,7 @@ def __init__(
8380
chat_model_args: BaseModelArgs,
8481
flags: GenericPromptFlags,
8582
max_retry: int = 4,
86-
actions: list[ToolSpec] | None = None,
83+
actions: list | None = None,
8784
):
8885

8986
self.chat_llm = chat_model_args.make_model()

src/agentlab/backends/browser/env.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,15 @@
1-
import json
21
import logging
32
import time
43
from dataclasses import dataclass
54
from pathlib import Path
6-
from typing import Literal
7-
8-
from tapeagents.core import Action, Observation, StopStep
9-
from tapeagents.tool_calling import ToolCallAction, ToolSpec
105

116
from agentlab.actions import ToolsActionSet
12-
from agentlab.backends.browser.base import BrowserBackend
7+
from agentlab.backends.browser.base import BrowserBackend, ToolCallAction, ToolSpec
138
from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
149
from agentlab.benchmarks.web_task import AbstractWebTask
1510

1611
logger = logging.getLogger(__name__)
1712

18-
19-
class GoalObservation(Observation):
20-
kind: Literal["goal_observation"] = "goal_observation"
21-
goal: str
22-
23-
24-
class PageObservation(Observation):
25-
kind: Literal["page_observation"] = "page_observation"
26-
content: str
27-
28-
2913
class BrowserEnv(AbstractEnv):
3014
def __init__(
3115
self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0
@@ -50,22 +34,23 @@ def reset(self, seed: int):
5034
page_content = self.backend.page_snapshot()
5135
screenshot = self.backend.page_screenshot()
5236
logger.info(f"Initial obs: {page_content}\n{screenshot}")
53-
return {
37+
obs = {
5438
"goal_object": [{"type": "text", "text": self.goal}],
5539
"pruned_html": page_content,
5640
"axtree_txt": page_content,
5741
"screenshot": screenshot,
5842
"last_action_error": "",
5943
"focused_element_bid": "none",
60-
}, {}
44+
}
45+
return self.task.obs_postprocess(obs), {}
6146

62-
def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool, bool, dict]:
47+
def step(self, action: ToolCallAction | str) -> tuple[dict, float, bool, bool, dict]:
6348
if isinstance(action, str):
6449
action = ToolsActionSet.parse_action(action)
6550
logger.info(f"BrowserEnv.step() called with action {action}")
6651

6752
action_exec_start = time.time()
68-
finished = isinstance(action, StopStep)
53+
finished = action.function.name == "final_step"
6954
if finished:
7055
observation = {
7156
"goal_object": [{"type": "text", "text": self.goal}],
@@ -76,6 +61,7 @@ def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool,
7661
}
7762
else:
7863
observation = self._step(action)
64+
observation = self.task.obs_postprocess(observation)
7965
action_exec_stop = time.time()
8066
self._turns += 1
8167
logger.info(f"Obs:\n{observation['pruned_html']}")
@@ -95,8 +81,7 @@ def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool,
9581
"action_exec_stop": action_exec_stop,
9682
"action_exec_timeout": 0.0,
9783
} | other
98-
obs_view = observation.short_view() if isinstance(observation, Observation) else observation
99-
logger.info(f"Action result in observation: {obs_view}")
84+
logger.info(f"Action result in observation: {observation}")
10085
return observation, reward, finished, truncated, env_info
10186

10287
def _step(self, action: ToolCallAction) -> dict:
@@ -108,7 +93,7 @@ def _step(self, action: ToolCallAction) -> dict:
10893
"focused_element_bid": "none",
10994
}
11095

111-
def validate_task(self, action: Action, observation: PageObservation) -> tuple[float, dict]:
96+
def validate_task(self, action: ToolCallAction, observation: dict) -> tuple[float, dict]:
11297
validate_js = self.task.get_step_validate_js()
11398
validate_result = self.backend.run_js(validate_js)
11499
reward, other = self.task.parse_validation_result(validate_result)

src/agentlab/backends/browser/mcp_playwright.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
import logging
33
from io import BytesIO
44

5-
from mcp.types import ImageContent, TextContent
65
from PIL import Image
7-
from tapeagents.tool_calling import ToolCallAction
86

9-
from agentlab.backends.browser.mcp import MCPBrowserBackend
7+
from agentlab.backends.browser.mcp import MCPBrowserBackend, ToolCallAction
108

119
logger = logging.getLogger(__name__)
1210

src/agentlab/benchmarks/miniwob/task.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, ClassVar
44

55
from browsergym.miniwob import ALL_MINIWOB_TASKS
6+
from PIL import Image
67

78
from agentlab.benchmarks.web_task import AbstractWebTask
89

@@ -130,6 +131,12 @@ def parse_validation_result(self, validation_result: str) -> tuple[float, dict]:
130131
"done": done,
131132
}
132133

134+
def obs_postprocess(self, obs: dict) -> dict:
135+
screenshot: Image.Image | None = obs.get("screenshot", None)
136+
if screenshot is not None:
137+
obs["screenshot"] = screenshot.crop((0, 0, 332, 214)) # crop to 332x214 because this is the viewport size for MiniWob
138+
return obs
139+
133140

134141
def get_miniwob_tasks(
135142
base_url: str | None = None, remove_human_display: bool = True, episode_max_time: int = 1000000

src/agentlab/benchmarks/web_task.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import ClassVar
22

33
from pydantic import BaseModel
4-
from tapeagents.tool_calling import ToolSpec
4+
5+
from agentlab.backends.browser.base import ToolSpec
56

67

78
class AbstractWebTask(BaseModel):
@@ -29,3 +30,6 @@ def get_step_validate_js(self) -> str:
2930

3031
def parse_validation_result(self, validate_result: str) -> tuple[float, dict]:
3132
raise NotImplementedError
33+
34+
def obs_postprocess(self, obs: dict) -> dict:
35+
return obs

0 commit comments

Comments
 (0)