Skip to content

Commit 450dacf

Browse files
committed
actions whitelist, fixes, support new order of the agent env creation in the loop
1 parent 5ceeb60 commit 450dacf

File tree

9 files changed

+139
-84
lines changed

9 files changed

+139
-84
lines changed

experiments/run_miniwob.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s"
1010
logging.basicConfig(level=logging.INFO, force=True, format=fmt, handlers=[logging.StreamHandler()])
11+
logger = logging.getLogger(__name__)
1112

1213
if __name__ == "__main__":
1314
config = load_config("gaia_l1")

src/agentlab/agents/tapeagent/agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tapeagents.agent import Agent
1010
from tapeagents.core import Action, Observation, StopStep, TapeMetadata, Thought
1111
from tapeagents.core import Tape as BaseTape
12+
from tapeagents.tool_calling import ToolSpec
1213

1314
from agentlab.agents.agent_args import AgentArgs
1415

@@ -40,8 +41,12 @@ def load_config(config_name: str) -> DictConfig:
4041
class TapeAgentArgs(AgentArgs):
4142
config: DictConfig = None # type: ignore
4243

43-
def make_agent(self) -> bgym.Agent:
44-
agent: Agent = hydra.utils.instantiate(self.config.agent)
44+
def make_agent(self, known_actions: tuple[ToolSpec, ...] | None) -> bgym.Agent:
45+
if known_actions is None:
46+
agent = hydra.utils.instantiate(self.config.agent)
47+
else:
48+
tools_description = "\n".join([action.description() for action in known_actions])
49+
agent = hydra.utils.instantiate(self.config.agent, known_actions=known_actions, tools_description=tools_description)
4550
return TapeAgent(agent=agent)
4651

4752

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,44 @@
1-
from tapeagents.environment import FunctionCall
2-
from tapeagents.mcp import MCPEnvironment, ToolCallAction
3-
from tapeagents.tool_calling import as_openai_tool
1+
from pydantic import BaseModel
2+
from tapeagents.mcp import MCPEnvironment
3+
from tapeagents.tool_calling import FunctionCall, ToolCallAction, ToolSpec
44

55

6-
class BrowserBackend():
7-
def run_js(self, js: str):
6+
class BrowserBackend(BaseModel):
7+
def initialize(self) -> None:
88
raise NotImplementedError
99

10-
def call_tool(self, tool_name: str, arguments: dict) -> str:
10+
def run_js(self, js: str):
1111
raise NotImplementedError
1212

13-
def tools_description(self) -> str:
13+
def goto(self, url: str) -> str:
1414
raise NotImplementedError
1515

16-
def tools(self) -> list[dict]:
16+
def step(self, action: ToolCallAction) -> str:
1717
raise NotImplementedError
1818

19+
def actions(self) -> tuple[ToolSpec]:
20+
raise NotImplementedError
1921

20-
class MCPBrowserBackend(BrowserBackend):
21-
def __init__(self, config_path: str):
22-
self.config_path = config_path
23-
self.mcp = MCPEnvironment(config_path=self.config_path)
24-
self.mcp.initialize()
2522

26-
def call_tool(self, tool_name: str, arguments: dict) -> str:
27-
action = ToolCallAction(
28-
function=FunctionCall(name=tool_name, arguments=arguments)
29-
)
30-
tool_result = self.mcp.step(action)
31-
return tool_result.content.content[0].text
3223

24+
class MCPBrowserBackend(BrowserBackend):
25+
config_path: str
26+
_mcp = None
27+
28+
def initialize(self) -> None:
29+
self._mcp = MCPEnvironment(config_path=self.config_path)
30+
self._mcp.initialize()
3331

34-
def tools_description(self) -> str:
35-
return self.mcp.tools_description()
32+
def step(self, action: ToolCallAction) -> str:
33+
return self._call_mcp(action)
3634

37-
def tools(self) -> list[dict]:
38-
actions = self.mcp.actions()
39-
tools = [as_openai_tool(a).model_dump() for a in actions]
40-
return tools
35+
def call_tool(self, tool_name: str, arguments: dict) -> str:
36+
return self._call_mcp(ToolCallAction(function=FunctionCall(name=tool_name, arguments=arguments)))
37+
38+
def _call_mcp(self, action: ToolCallAction) -> str:
39+
tool_result = self._mcp.step(action)
40+
texts = [c.text for c in tool_result.content.content]
41+
return "\n\n".join(texts)
42+
43+
def actions(self) -> tuple[ToolSpec]:
44+
return self._mcp.actions()

src/agentlab/backends/browser/env.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,48 @@
11
import logging
22
import time
3-
from typing import Any, Literal
3+
from dataclasses import dataclass
4+
from pathlib import Path
5+
from typing import Literal
46

57
from tapeagents.core import Action, Observation, StopStep
8+
from tapeagents.tool_calling import ToolCallAction, ToolSpec
69

710
from agentlab.backends.browser.base import BrowserBackend
8-
from agentlab.benchmarks.abstract_env import AbstractEnv
11+
from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
912
from agentlab.benchmarks.miniwob.task import AbstractWebTask
1013

1114
logger = logging.getLogger(__name__)
1215

16+
class GoalObservation(Observation):
17+
kind: Literal["goal_observation"] = "goal_observation"
18+
goal: str
1319

1420
class PageObservation(Observation):
1521
kind: Literal["page_observation"] = "page_observation"
1622
content: str
1723

18-
class BrowserAction(Action):
19-
kind: Literal["browser_action"] = "browser_action"
20-
name: str
21-
arguments: dict[str, Any]
22-
2324

2425
class BrowserEnv(AbstractEnv):
2526
def __init__(self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0):
2627
self.task_name = task_name
2728
self.task = task
2829
self.seed = seed
29-
self.backend = backend
3030
self._turns = 0
31+
self.backend = backend
32+
self.backend.initialize()
3133

3234
def reset(self, seed: int):
3335
self.seed = seed
36+
logger.info(f"Open task URL: {self.task.url}")
37+
page_content = self.backend.goto(self.task.url)
3438
setup_js = self.task.get_setup_js()
3539
if setup_js:
3640
js_result_str = self.backend.run_js(setup_js)
3741
logger.info(f"Task reset result: {js_result_str}")
42+
return [GoalObservation(goal=js_result_str), PageObservation(content=page_content)], {}
3843

39-
def step(self, action: BrowserAction) -> tuple[Observation, float, bool, bool, dict]:
40-
logger.info(f"BrowserEnv.step() called with action {type(action)}")
44+
def step(self, action: ToolCallAction) -> tuple[Observation, float, bool, bool, dict]:
45+
logger.info(f"BrowserEnv.step() called with action {action.function.name}")
4146

4247
action_exec_start = time.time()
4348
finished = isinstance(action, StopStep)
@@ -65,8 +70,8 @@ def step(self, action: BrowserAction) -> tuple[Observation, float, bool, bool, d
6570
logger.info(f"Action result in observation: {obs_view}")
6671
return observation, reward, finished, truncated, env_info
6772

68-
def _step(self, action: Action) -> PageObservation:
69-
tool_result = self.backend.call_tool(action.name, action.arguments)
73+
def _step(self, action: ToolCallAction) -> PageObservation:
74+
tool_result = self.backend.step(action)
7075
return PageObservation(content=tool_result)
7176

7277
def calculate_reward(self, action: Action, observation: PageObservation) -> float:
@@ -80,3 +85,28 @@ def close(self):
8085
if teardown_js:
8186
js_result_str = self.backend.run_js(teardown_js)
8287
logger.info(f"Task teardown result: {js_result_str}")
88+
89+
def actions(self) -> list[ToolSpec]:
90+
all_actions = self.backend.actions()
91+
filtered_actions = self.task.filter_actions(all_actions)
92+
logger.info(f"Filtered {len(filtered_actions)} actions out of {len(all_actions)} for task {self.task.dataset}")
93+
return filtered_actions
94+
95+
96+
@dataclass
97+
class BrowserEnvArgs(AbstractEnvArgs):
98+
task: AbstractWebTask
99+
task_seed: int
100+
task_name: str
101+
backend: BrowserBackend
102+
103+
def __init__(self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, task_seed: int = 0):
104+
self.task_name = task_name
105+
self.task = task
106+
self.task_seed = task_seed
107+
self.backend = backend
108+
109+
def make_env(self, exp_dir: Path) -> BrowserEnv:
110+
env = BrowserEnv(task_name=self.task_name, task=self.task, backend=self.backend, seed=self.task_seed)
111+
return env
112+
Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,30 @@
1+
import logging
2+
3+
from tapeagents.tool_calling import ToolCallAction
4+
15
from agentlab.backends.browser.base import MCPBrowserBackend
26

7+
logger = logging.getLogger(__name__)
8+
39
DEFAULT_CONFIG_PATH = "src/agentlab/backends/browser/mcp_playwright.json"
410

11+
512
class MCPPlaywright(MCPBrowserBackend):
6-
def __init__(self, config_path: str | None = None):
7-
super().__init__(config_path or DEFAULT_CONFIG_PATH)
13+
config_path: str = DEFAULT_CONFIG_PATH
814

915
def run_js(self, js: str):
1016
raw_response = self.call_tool("browser_evaluate", {"function": js})
1117
_, half_response = raw_response.split("### Result", maxsplit=1)
1218
result_str, _ = half_response.split("\n### Ran", maxsplit=1)
1319
result_str = result_str.strip()
1420
return result_str
21+
22+
def step(self, action: ToolCallAction) -> str:
23+
tool_result = self._call_mcp(action)
24+
logger.info(f"Tool result: {tool_result}")
25+
snapshot = self.call_tool("browser_snapshot", {})
26+
return snapshot
27+
28+
def goto(self, url: str) -> str:
29+
tool_result = self.call_tool("browser_navigate", {"url": url})
30+
return tool_result
Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,18 @@
11
import logging
2-
from dataclasses import dataclass
3-
from pathlib import Path
42
from typing import Any
53

6-
from pydantic import ConfigDict
7-
84
from agentlab.backends.browser.base import BrowserBackend
9-
from agentlab.backends.browser.env import BrowserEnv
10-
from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnvArgs
5+
from agentlab.backends.browser.env import BrowserEnvArgs
6+
from agentlab.benchmarks.abstract_env import AbstractBenchmark
117
from agentlab.benchmarks.miniwob.task import MiniWobTask, get_miniwob_tasks
128

139
logger = logging.getLogger(__name__)
1410

1511

16-
@dataclass
17-
class MiniwobArgs(AbstractEnvArgs):
18-
task: MiniWobTask
19-
task_seed: int
20-
task_name: str
21-
backend: BrowserBackend
22-
23-
def __init__(self, task_name: str, task: MiniWobTask, backend: BrowserBackend, task_seed: int = 0):
24-
self.task_name = task_name
25-
self.task = task
26-
self.task_seed = task_seed
27-
self.backend = backend
28-
29-
def make_env(self, exp_dir: Path, action_mapping=None) -> BrowserEnv:
30-
env = BrowserEnv(task_name=self.task_name, task=self.task, backend=self.backend, seed=self.task_seed)
31-
return env
32-
33-
3412
class MiniWobBenchmark(AbstractBenchmark):
35-
model_config = ConfigDict(arbitrary_types_allowed=True)
36-
3713
backend: BrowserBackend
3814
name: str = "miniwob"
39-
env_args_list: list[MiniwobArgs] = None # type: ignore
15+
env_args_list: list[BrowserEnvArgs] = None # type: ignore
4016
dataset: list[MiniWobTask] = None # type: ignore
4117

4218
def model_post_init(self, __context: Any) -> None:
@@ -45,6 +21,6 @@ def model_post_init(self, __context: Any) -> None:
4521
self.dataset = get_miniwob_tasks()
4622
for task in self.dataset:
4723
name = f"miniwob.{task.task_id}"
48-
env_args = MiniwobArgs(task_name=name, task=task, backend=self.backend)
24+
env_args = BrowserEnvArgs(task_name=name, task=task, backend=self.backend)
4925
self.env_args_list.append(env_args)
5026
logger.info(f"Loaded {len(self.env_args_list)} miniwob tasks")

src/agentlab/benchmarks/miniwob/task.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
import os
3-
from typing import Any
3+
from typing import Any, ClassVar
44

55
from browsergym.miniwob import ALL_MINIWOB_TASKS
66

@@ -16,6 +16,15 @@ class MiniWobTask(AbstractWebTask):
1616
url: str = None
1717
remove_human_display: bool = True
1818
episode_max_time: int = 1000000
19+
actions_whitelist: ClassVar[list[str]] = [
20+
"browser_press_key",
21+
"browser_type",
22+
"browser_navigate",
23+
"browser_click",
24+
"browser_drag",
25+
"browser_hover",
26+
"browser_select_option",
27+
]
1928

2029
def model_post_init(self, __context: Any):
2130
self.url = f"{self.base_url}/{self.subdomain}.html"

src/agentlab/benchmarks/web_task.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1+
from typing import ClassVar
12

23
from pydantic import BaseModel
4+
from tapeagents.tool_calling import ToolSpec
35

46

57
class AbstractWebTask(BaseModel):
6-
name: str
8+
dataset: str
9+
url: str
710
validate_per_step: bool = False
8-
11+
actions_whitelist: ClassVar[list[str]] = []
12+
13+
@classmethod
14+
def filter_actions(cls, actions: list[ToolSpec]) -> list[str]:
15+
return [action for action in actions if action.function.name in cls.actions_whitelist]
16+
917
def get_setup_js(self) -> str:
1018
raise NotImplementedError
1119

src/agentlab/experiments/loop.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from PIL import Image
2626
from tqdm import tqdm
2727

28+
from agentlab.backends.browser.env import BrowserEnvArgs
29+
2830
try:
2931
from agentlab.agents.tapeagent import TapeAgent, save_tape
3032
except ImportError:
@@ -414,19 +416,23 @@ def run(self):
414416
env, step_info, err_msg, stack_trace = None, None, None, None
415417
try:
416418
logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
417-
agent = self.agent_args.make_agent()
418-
if hasattr(agent, "set_task_name"):
419-
agent.set_task_name(self.env_args.task_name)
420-
421-
logger.debug("Agent created.")
422-
423-
env = self.env_args.make_env(
424-
action_mapping=agent.action_set.to_python_code,
425-
exp_dir=self.exp_dir,
426-
use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False),
427-
)
419+
if isinstance(self.env_args, BrowserEnvArgs):
420+
env = self.env_args.make_env(exp_dir=self.exp_dir)
421+
logger.debug("Environment created.")
422+
agent = self.agent_args.make_agent(known_actions=env.actions())
423+
logger.debug(f"Agent created with actions: {env.actions()}")
424+
else:
425+
agent = self.agent_args.make_agent()
426+
if hasattr(agent, "set_task_name"):
427+
agent.set_task_name(self.env_args.task_name)
428+
logger.debug("Agent created.")
429+
env = self.env_args.make_env(
430+
action_mapping=agent.action_set.to_python_code,
431+
exp_dir=self.exp_dir,
432+
use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False),
433+
)
434+
logger.debug("Environment created.")
428435

429-
logger.debug("Environment created.")
430436
step_info = StepInfo(step=0)
431437
episode_info = [step_info]
432438
step_info.from_reset(

0 commit comments

Comments
 (0)