Skip to content

Commit 9ee2367

Browse files
committed
moving the browsergym.experiment.benchmark module to agentlab
1 parent cad0629 commit 9ee2367

25 files changed

+35541
-55
lines changed

src/agentlab/agents/agent_args.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import bgym
22
from bgym import AbstractAgentArgs
33

4+
from agentlab.experiments.benchmark import Benchmark
5+
46

57
class AgentArgs(AbstractAgentArgs):
68
"""Base class for agent arguments for instantiating an agent.
@@ -14,7 +16,7 @@ class MyAgentArgs(AgentArgs):
1416
Note: for working properly with AgentXRay, the arguments need to be serializable and hasable.
1517
"""
1618

17-
def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool):
19+
def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
1820
"""Optional method to set benchmark specific flags.
1921
2022
This allows the agent to have minor adjustments based on the benchmark.

src/agentlab/agents/dynamic_prompting.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,9 @@
1010

1111
import bgym
1212
from browsergym.core.action.base import AbstractActionSet
13-
from browsergym.utils.obs import (
14-
flatten_axtree_to_str,
15-
flatten_dom_to_str,
16-
overlay_som,
17-
prune_html,
18-
)
13+
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html
1914

15+
from agentlab.experiments.benchmark import HighLevelActionSetArgs
2016
from agentlab.llm.llm_utils import (
2117
BaseMessage,
2218
ParseError,
@@ -99,7 +95,7 @@ class ObsFlags(Flags):
9995

10096
@dataclass
10197
class ActionFlags(Flags):
102-
action_set: bgym.HighLevelActionSetArgs = None # should be set by the set_benchmark method
98+
action_set: HighLevelActionSetArgs = None # should be set by the set_benchmark method
10399
long_description: bool = True
104100
individual_examples: bool = False
105101

src/agentlab/agents/generic_agent/agent_configs.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from agentlab.agents import dynamic_prompting as dp
88
from agentlab.experiments import args
9+
from agentlab.experiments.benchmark import HighLevelActionSetArgs
910
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
1011

1112
from .generic_agent import GenericAgentArgs
@@ -31,7 +32,7 @@
3132
filter_visible_elements_only=False,
3233
),
3334
action=dp.ActionFlags(
34-
action_set=bgym.HighLevelActionSetArgs(
35+
action_set=HighLevelActionSetArgs(
3536
subsets=["bid"],
3637
multiaction=False,
3738
),
@@ -79,7 +80,7 @@
7980
filter_visible_elements_only=False,
8081
),
8182
action=dp.ActionFlags(
82-
action_set=bgym.HighLevelActionSetArgs(
83+
action_set=HighLevelActionSetArgs(
8384
subsets=["bid"],
8485
multiaction=False,
8586
),
@@ -126,7 +127,7 @@
126127
filter_visible_elements_only=False,
127128
),
128129
action=dp.ActionFlags(
129-
action_set=bgym.HighLevelActionSetArgs(
130+
action_set=HighLevelActionSetArgs(
130131
subsets=["bid"],
131132
multiaction=False,
132133
),
@@ -176,7 +177,7 @@
176177
filter_visible_elements_only=False,
177178
),
178179
action=dp.ActionFlags(
179-
action_set=bgym.HighLevelActionSetArgs(
180+
action_set=HighLevelActionSetArgs(
180181
subsets=["bid"],
181182
multiaction=True,
182183
),
@@ -231,7 +232,7 @@
231232
filter_visible_elements_only=False,
232233
),
233234
action=dp.ActionFlags(
234-
action_set=bgym.HighLevelActionSetArgs(
235+
action_set=HighLevelActionSetArgs(
235236
subsets=["bid"],
236237
multiaction=False,
237238
),
@@ -319,7 +320,7 @@
319320
filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]),
320321
),
321322
action=dp.ActionFlags(
322-
action_set=bgym.HighLevelActionSetArgs(
323+
action_set=HighLevelActionSetArgs(
323324
subsets=args.Choice([["bid"], ["bid", "coord"]]),
324325
multiaction=args.Choice([True, False], p=[0.7, 0.3]),
325326
),

src/agentlab/agents/generic_agent/generic_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@
1010

1111
from copy import deepcopy
1212
from dataclasses import asdict, dataclass
13+
from functools import partial
1314
from warnings import warn
1415

1516
import bgym
1617
from browsergym.experiments.agent import Agent, AgentInfo
1718

1819
from agentlab.agents import dynamic_prompting as dp
1920
from agentlab.agents.agent_args import AgentArgs
21+
from agentlab.experiments.benchmark import Benchmark
2022
from agentlab.llm.chat_api import BaseModelArgs
2123
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
2224
from agentlab.llm.tracking import cost_tracker_decorator
2325

2426
from .generic_agent_prompt import GenericPromptFlags, MainPrompt
25-
from functools import partial
2627

2728

2829
@dataclass
@@ -37,7 +38,7 @@ def __post_init__(self):
3738
except AttributeError:
3839
pass
3940

40-
def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
41+
def set_benchmark(self, benchmark: Benchmark, demo_mode):
4142
"""Override Some flags based on the benchmark."""
4243
if benchmark.name.startswith("miniwob"):
4344
self.flags.obs.use_html = True

src/agentlab/agents/generic_agent/reproducibility_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bs4 import BeautifulSoup
2424

2525
from agentlab.agents.agent_args import AgentArgs
26+
from agentlab.experiments.benchmark import HighLevelActionSetArgs
2627
from agentlab.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results
2728
from agentlab.experiments.study import Study
2829
from agentlab.llm.chat_api import make_assistant_message
@@ -144,7 +145,7 @@ def _make_backward_compatible(agent_args: GenericAgentArgs):
144145
if isinstance(action_set, str):
145146
action_set = action_set.split("+")
146147

147-
agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs(
148+
agent_args.flags.action.action_set = HighLevelActionSetArgs(
148149
subsets=action_set,
149150
multiaction=agent_args.flags.action.multi_actions,
150151
)

src/agentlab/agents/visual_agent/agent_configs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import bgym
2+
3+
import agentlab.agents.dynamic_prompting as dp
4+
from agentlab.experiments.benchmark import HighLevelActionSetArgs
15
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
26

37
from .visual_agent import VisualAgentArgs
48
from .visual_agent_prompts import PromptFlags
5-
import agentlab.agents.dynamic_prompting as dp
6-
import bgym
79

810
# the other flags are ignored for this agent.
911
DEFAULT_OBS_FLAGS = dp.ObsFlags(
@@ -16,7 +18,7 @@
1618
)
1719

1820
DEFAULT_ACTION_FLAGS = dp.ActionFlags(
19-
action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]),
21+
action_set=HighLevelActionSetArgs(subsets=["coord"]),
2022
long_description=True,
2123
individual_examples=False,
2224
)

src/agentlab/agents/visual_agent/visual_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515

1616
from agentlab.agents import dynamic_prompting as dp
1717
from agentlab.agents.agent_args import AgentArgs
18+
from agentlab.experiments.benchmark import Benchmark
1819
from agentlab.llm.chat_api import BaseModelArgs
1920
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
2021
from agentlab.llm.tracking import cost_tracker_decorator
2122

22-
from .visual_agent_prompts import PromptFlags, MainPrompt
23+
from .visual_agent_prompts import MainPrompt, PromptFlags
2324

2425

2526
@dataclass
@@ -34,7 +35,7 @@ def __post_init__(self):
3435
except AttributeError:
3536
pass
3637

37-
def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
38+
def set_benchmark(self, benchmark: Benchmark, demo_mode):
3839
"""Override Some flags based on the benchmark."""
3940
self.flags.obs.use_tabs = benchmark.is_multi_tab
4041

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base import Benchmark, HighLevelActionSetArgs
2+
from .configs import DEFAULT_BENCHMARKS

0 commit comments

Comments
 (0)