Skip to content

Commit 4e27c3a

Browse files
committed
fixes
1 parent f278c0f commit 4e27c3a

File tree

7 files changed

+150
-38
lines changed

7 files changed

+150
-38
lines changed

experiments/run_miniwob.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import logging
22
import os
33

4+
from bgym import DEFAULT_BENCHMARKS
45
from dotenv import load_dotenv
56

7+
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
8+
from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs
69
from agentlab.agents.tapeagent.agent import TapeAgentArgs, load_config
710
from agentlab.backends.browser.mcp_playwright import MCPPlaywright
811
from agentlab.benchmarks.miniwob import MiniWobBenchmark
912
from agentlab.experiments.study import make_study
13+
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
1014

1115
fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s"
1216
logging.basicConfig(level=logging.INFO, force=True, format=fmt, handlers=[logging.StreamHandler()])
@@ -15,15 +19,26 @@
1519

1620
if __name__ == "__main__":
    # The loaded config still supplies the study comment and the parallelism
    # settings used below, even though the agent itself is built explicitly.
    config = load_config("miniwob")

    # Generic agent backed by Claude Sonnet 4, reusing the GPT-4o prompt flags.
    agent_args = GenericAgentArgs(
        chat_model_args=CHAT_MODEL_ARGS_DICT["anthropic/claude-sonnet-4-20250514"],
        flags=FLAGS_GPT_4o,
    )
    benchmark = MiniWobBenchmark(backend=MCPPlaywright())

    study = make_study(
        benchmark=benchmark,
        agent_args=agent_args,
        comment=config.comment,
        logging_level=logging.INFO,
        logging_level_stdout=logging.INFO,
    )
    if os.environ.get("AGENTLAB_DEBUG"):
        # Debug mode runs only the second experiment. Fall back to the first
        # one when the study has fewer than two, so a debug run is never
        # silently empty.
        study.exp_args_list = study.exp_args_list[1:2] or study.exp_args_list[:1]
        study.run(n_jobs=1, n_relaunch=1, parallel_backend="sequential")
    else:
        study.run(n_jobs=config.n_jobs, n_relaunch=1, parallel_backend=config.parallel_backend)

src/agentlab/actions.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
1+
import json
2+
import logging
3+
14
from bgym import AbstractActionSet
25
from tapeagents.tool_calling import FunctionCall, ToolCallAction, ToolSpec
36

47
from agentlab.llm.llm_utils import parse_html_tags_raise
58

9+
logger = logging.getLogger(__name__)
10+
611

712
class ToolsActionSet(AbstractActionSet):
8-
def __init__(self, actions:list[ToolSpec]):
13+
multiaction: bool = False
14+
strict: bool = False
15+
16+
def __init__(self, actions: list[ToolSpec]):
917
self.actions = actions
1018

1119
def describe(self, with_long_description: bool = True, with_examples: bool = True) -> str:
    """Render every tool as a markdown heading followed by its JSON schema.

    Args:
        with_long_description: Accepted for interface compatibility; not
            used by this implementation.
        with_examples: Accepted for interface compatibility; not used by
            this implementation.

    Returns:
        One entry per tool in ``self.actions``, joined with newlines. An
        empty string when there are no tools.
    """
    return "\n".join(
        f"## {action.description()}.\n Schema: {action.model_dump_json(indent=2)}"
        for action in self.actions
    )
1426

1527
def example_action(self, abstract: bool) -> str:
@@ -35,13 +47,26 @@ def example_action(self, abstract: bool) -> str:
3547
}
3648
</action>
3749
"""
50+
3851
@classmethod
def parse_action(cls, llm_output: str) -> ToolCallAction:
    """Turn raw LLM output into a ``ToolCallAction``.

    The output is either a JSON object wrapped in ``<action>...</action>``
    tags, or the bare JSON object itself. The object must provide the keys
    ``name`` and ``arguments``.

    Args:
        llm_output: Text produced by the model.

    Returns:
        A ``ToolCallAction`` wrapping a ``FunctionCall`` built from the
        parsed ``name`` and ``arguments``.

    Raises:
        ValueError: If the ``<action>`` tag cannot be extracted, the payload
            is not valid JSON, or the JSON is not an object carrying both
            ``name`` and ``arguments``.
    """
    logger.info(f"Parsing action: {llm_output}")
    if "<action>" in llm_output:
        content_dict, valid, retry_message = parse_html_tags_raise(llm_output, keys=["action"])
        if not valid or "action" not in content_dict:
            raise ValueError(
                f"Invalid action: llm_output: {llm_output}, retry_message: {retry_message}"
            )
        action_str = content_dict["action"]
    else:
        action_str = llm_output
    try:
        action_dict = json.loads(action_str)
    except json.JSONDecodeError as err:
        # Keep the decoder error as the cause so the failing position stays visible.
        raise ValueError(f"Failed to parse action: {action_str}") from err
    # Report malformed payloads with the same exception type as every other
    # invalid action instead of leaking a KeyError/TypeError.
    if not isinstance(action_dict, dict) or not {"name", "arguments"} <= action_dict.keys():
        raise ValueError(
            f"Action must be a JSON object with 'name' and 'arguments': {action_str}"
        )
    return ToolCallAction(
        function=FunctionCall(name=action_dict["name"], arguments=action_dict["arguments"])
    )
4570

4671
def to_python_code(self, action) -> str:
47-
return action.model_dump_json(indent=2)
72+
return action

src/agentlab/agents/generic_agent/generic_agent.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import bgym
1717
from bgym import Benchmark
1818
from browsergym.experiments.agent import Agent, AgentInfo
19+
from tapeagents.tool_calling import ToolSpec
1920

21+
from agentlab.actions import ToolsActionSet
2022
from agentlab.agents import dynamic_prompting as dp
2123
from agentlab.agents.agent_args import AgentArgs
2224
from agentlab.llm.chat_api import BaseModelArgs
@@ -65,9 +67,12 @@ def prepare(self):
6567
def close(self):
6668
return self.chat_model_args.close_server()
6769

68-
def make_agent(self):
70+
def make_agent(self, actions: list[ToolSpec] | None = None):
6971
return GenericAgent(
70-
chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
72+
chat_model_args=self.chat_model_args,
73+
flags=self.flags,
74+
max_retry=self.max_retry,
75+
actions=actions,
7176
)
7277

7378

@@ -78,15 +83,21 @@ def __init__(
7883
chat_model_args: BaseModelArgs,
7984
flags: GenericPromptFlags,
8085
max_retry: int = 4,
86+
actions: list[ToolSpec] | None = None,
8187
):
8288

8389
self.chat_llm = chat_model_args.make_model()
8490
self.chat_model_args = chat_model_args
8591
self.max_retry = max_retry
8692

8793
self.flags = flags
88-
self.action_set = self.flags.action.action_set.make_action_set()
89-
self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
94+
if actions is not None:
95+
self.action_set = ToolsActionSet(actions=actions)
96+
self.flags.action.action_set = self.action_set
97+
self._obs_preprocessor = lambda obs: obs
98+
else:
99+
self.action_set = self.flags.action.action_set.make_action_set()
100+
self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
90101

91102
self._check_flag_constancy()
92103
self.reset(seed=None)

src/agentlab/backends/browser/base.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,23 @@ def run_js(self, js: str):
1313
def goto(self, url: str) -> str:
1414
raise NotImplementedError
1515

16+
def page_snapshot(self) -> str:
17+
raise NotImplementedError
18+
1619
def step(self, action: ToolCallAction) -> str:
1720
raise NotImplementedError
1821

1922
def actions(self) -> tuple[ToolSpec]:
2023
raise NotImplementedError
2124

25+
def close(self) -> None:
26+
raise NotImplementedError
2227

2328

2429
class MCPBrowserBackend(BrowserBackend):
2530
config_path: str
2631
_mcp = None
27-
32+
2833
def initialize(self) -> None:
2934
self._mcp = MCPEnvironment(config_path=self.config_path)
3035
self._mcp.initialize()
@@ -33,12 +38,17 @@ def step(self, action: ToolCallAction) -> str:
3338
return self._call_mcp(action)
3439

3540
def call_tool(self, tool_name: str, arguments: dict) -> str:
36-
return self._call_mcp(ToolCallAction(function=FunctionCall(name=tool_name, arguments=arguments)))
37-
41+
return self._call_mcp(
42+
ToolCallAction(function=FunctionCall(name=tool_name, arguments=arguments))
43+
)
44+
3845
def _call_mcp(self, action: ToolCallAction) -> str:
3946
tool_result = self._mcp.step(action)
4047
texts = [c.text for c in tool_result.content.content]
4148
return "\n\n".join(texts)
4249

4350
def actions(self) -> tuple[ToolSpec]:
44-
return self._mcp.actions()
51+
return self._mcp.actions()
52+
53+
def close(self) -> None:
    """Shut down the MCP environment if one was started.

    ``_mcp`` stays ``None`` until ``initialize()`` runs, so closing a backend
    that was never initialized is a no-op rather than an ``AttributeError``.
    """
    if self._mcp is not None:
        self._mcp.close()

src/agentlab/backends/browser/env.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import logging
23
import time
34
from dataclasses import dataclass
@@ -10,38 +11,60 @@
1011
from agentlab.actions import ToolsActionSet
1112
from agentlab.backends.browser.base import BrowserBackend
1213
from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
13-
from agentlab.benchmarks.miniwob.task import AbstractWebTask
14+
from agentlab.benchmarks.web_task import AbstractWebTask
1415

1516
logger = logging.getLogger(__name__)
1617

18+
1719
class GoalObservation(Observation):
1820
kind: Literal["goal_observation"] = "goal_observation"
1921
goal: str
2022

23+
2124
class PageObservation(Observation):
2225
kind: Literal["page_observation"] = "page_observation"
2326
content: str
2427

2528

2629
class BrowserEnv(AbstractEnv):
27-
def __init__(self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0):
30+
def __init__(
31+
self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0
32+
):
2833
self.task_name = task_name
2934
self.task = task
3035
self.seed = seed
3136
self._turns = 0
3237
self.max_turns = task.max_turns
3338
self.backend = backend
3439
self.backend.initialize()
40+
self.goal = ""
3541

3642
def reset(self, seed: int):
    """Open the task page, run its setup script and return the first observation.

    Args:
        seed: Stored on the environment as ``self.seed``.

    Returns:
        A pair ``(observation, info)``: the observation dict holds the goal
        and the backend's page snapshot, and ``info`` is an empty dict.

    Raises:
        ValueError: If the setup script reports the task as already done.
    """
    self.seed = seed
    logger.info(f"Open task URL: {self.task.url}")
    self.backend.goto(self.task.url)

    setup_script = self.task.get_setup_js()
    if setup_script:
        # The setup script answers with a JSON document describing the task state.
        setup_info = json.loads(self.backend.run_js(setup_script))
        logger.info(f"Task setup result: {setup_info}")
        task_goal = setup_info["goal"]
        already_done = setup_info["done"]
        logger.info(f"Task start time: {setup_info['task_start_time']}")
        if already_done:
            raise ValueError("Task is already done")
        self.goal = task_goal
        logger.info(f"Task goal: {self.goal}")

    snapshot = self.backend.page_snapshot()
    logger.info(f"Initial obs: {snapshot}")
    first_observation = dict(
        goal_object=[{"type": "text", "text": self.goal}],
        pruned_html=snapshot,
        axtree_txt="",
        last_action_error="",
        focused_element_bid="none",
    )
    return first_observation, {}
4568

4669
def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool, bool, dict]:
4770
if isinstance(action, str):
@@ -51,49 +74,67 @@ def step(self, action: ToolCallAction | str) -> tuple[Observation, float, bool,
5174
action_exec_start = time.time()
5275
finished = isinstance(action, StopStep)
5376
if finished:
54-
observation = Observation() # empty observation
77+
observation = {
78+
"goal_object": [{"type": "text", "text": self.goal}],
79+
"pruned_html": "Task finished",
80+
"axtree_txt": "",
81+
"last_action_error": "",
82+
"focused_element_bid": "none",
83+
}
5584
else:
5685
observation = self._step(action)
5786
action_exec_stop = time.time()
5887
self._turns += 1
88+
logger.info(f"Obs:\n{observation['pruned_html']}")
5989

6090
truncated = self._turns >= self.max_turns
6191

6292
if self.task.validate_per_step or finished or truncated:
63-
reward = self.calculate_reward(action, observation)
93+
reward, other = self.calculate_reward(action, observation)
94+
if other.get("done", False):
95+
finished = True
6496
else:
6597
reward = 0.0
98+
other = {}
6699

67100
env_info = {
68-
"step_metadata": observation.metadata,
69101
"action_exec_start": action_exec_start,
70102
"action_exec_stop": action_exec_stop,
71103
"action_exec_timeout": 0.0,
72-
}
104+
} | other
73105
obs_view = observation.short_view() if isinstance(observation, Observation) else observation
74106
logger.info(f"Action result in observation: {obs_view}")
75107
return observation, reward, finished, truncated, env_info
76108

77-
def _step(self, action: ToolCallAction) -> PageObservation:
109+
def _step(self, action: ToolCallAction) -> dict:
    """Execute one tool call on the backend and wrap its output as an observation.

    Args:
        action: The tool call forwarded to ``self.backend.step``.

    Returns:
        An observation dict whose ``pruned_html`` entry is the backend's
        result and whose ``goal_object`` entry carries the current goal.
    """
    page_text = self.backend.step(action)
    goal_entry = {"type": "text", "text": self.goal}
    return dict(
        goal_object=[goal_entry],
        pruned_html=page_text,
        axtree_txt="",
        last_action_error="",
        focused_element_bid="none",
    )
80118

81-
def calculate_reward(self, action: Action, observation: PageObservation) -> float:
119+
def calculate_reward(self, action: Action, observation: dict) -> tuple[float, dict]:
    """Score the current page state with the task's validation script.

    Neither argument is consulted here: the result comes entirely from
    running the task's validation JavaScript through the backend.

    Args:
        action: The action that was just executed (unused).
        observation: The observation dict produced for that action (unused).

    Returns:
        The ``(reward, other)`` pair returned by
        ``self.task.parse_validation_result``.
    """
    validate_js = self.task.get_step_validate_js()
    validate_result = self.backend.run_js(validate_js)
    reward, other = self.task.parse_validation_result(validate_result)
    return reward, other
86124

87125
def close(self):
    """Run the task's teardown script, then always release the backend.

    The backend is closed in a ``finally`` clause so a failing teardown
    script cannot leak it; the teardown error still propagates to the
    caller.
    """
    try:
        teardown_js = self.task.get_teardown_js()
        if teardown_js:
            js_result_str = self.backend.run_js(teardown_js)
            logger.info(f"Task teardown result: {js_result_str}")
    finally:
        self.backend.close()
92131

93132
def actions(self) -> list[ToolSpec]:
    """Return the backend's tools after the task has filtered them.

    Returns:
        Whatever ``self.task.filter_actions`` yields for the full list of
        backend actions.
    """
    available = self.backend.actions()
    permitted = self.task.filter_actions(available)
    logger.info(
        f"Filtered {len(permitted)} actions out of {len(available)} for task {self.task.dataset}"
    )
    return permitted
98139

99140

@@ -104,13 +145,16 @@ class BrowserEnvArgs(AbstractEnvArgs):
104145
task_name: str
105146
backend: BrowserBackend
106147

107-
def __init__(self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, task_seed: int = 0):
148+
def __init__(
149+
self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, task_seed: int = 0
150+
):
108151
self.task_name = task_name
109152
self.task = task
110153
self.task_seed = task_seed
111154
self.backend = backend
112155

113156
def make_env(self, exp_dir: Path) -> BrowserEnv:
114-
env = BrowserEnv(task_name=self.task_name, task=self.task, backend=self.backend, seed=self.task_seed)
157+
env = BrowserEnv(
158+
task_name=self.task_name, task=self.task, backend=self.backend, seed=self.task_seed
159+
)
115160
return env
116-
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .benchmark import MiniWobBenchmark
22
from .task import MiniWobTask
33

4-
__all__ = ["MiniWobBenchmark", "MiniWobTask"]
4+
__all__ = ["MiniWobBenchmark", "MiniWobTask"]

0 commit comments

Comments
 (0)