Commit 6de7a02

adding benchmark stuff back in
1 parent 31c3224 commit 6de7a02

12 files changed: +35480 -0 lines

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .base import Benchmark, HighLevelActionSetArgs
from .configs import DEFAULT_BENCHMARKS
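
The two-line module above re-exports the package's public API. A minimal sketch of how a downstream script might consume it; the package import path and the exact structure of DEFAULT_BENCHMARKS (defined in .configs, not shown in this diff) are assumptions:

# NOTE: "my_benchmarks_pkg" stands in for wherever this package lives, and the
# "miniwob" key is illustrative; neither is specified in this diff.
from my_benchmarks_pkg import DEFAULT_BENCHMARKS, Benchmark

benchmark: Benchmark = DEFAULT_BENCHMARKS["miniwob"]
# If DEFAULT_BENCHMARKS stores factories instead of instances, call the entry:
# benchmark = DEFAULT_BENCHMARKS["miniwob"]()
benchmark.prepare_backends()  # set up the backing environment(s) before running tasks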

Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
import fnmatch
import logging
import random
import typing
from dataclasses import dataclass, field
from typing import Literal, Optional

import pandas as pd
from browsergym.core.action.highlevel import HighLevelActionSet
from dataclasses_json import DataClassJsonMixin, config

from agentlab.experiments.loop import EnvArgs

from .metadata.utils import (
    build_env_args_dependency_graphs,
    build_full_task_dependency_graph_from_metadata,
    extract_sparse_task_dependency_graph_from_subset,
    task_list_from_metadata,
)
from .utils import prepare_backend

logger = logging.getLogger(__name__)


@dataclass
class HighLevelActionSetArgs(DataClassJsonMixin):
    subsets: tuple[HighLevelActionSet.ActionSubset, ...] = field(
        metadata=config(
            encoder=lambda x: list(x),
            decoder=lambda x: tuple(x),
        ),
    )
    # custom_actions: list[callable] | None  # non-serializable argument, not supported
    multiaction: bool = False
    strict: bool = False
    retry_with_force: bool = False
    demo_mode: Optional[Literal["off", "default", "all_blue", "only_visible_elements"]] = None

    def __post_init__(self):
        if isinstance(self.subsets, list):
            # needs to be hashable for AgentLab when uniquely identifying agents
            self.subsets = tuple(self.subsets)

    def make_action_set(self):
        return HighLevelActionSet(
            subsets=self.subsets,
            custom_actions=None,
            multiaction=self.multiaction,
            strict=self.strict,
            retry_with_force=self.retry_with_force,
            demo_mode=self.demo_mode,
        )


BenchmarkBackend = Literal[
    "miniwob", "webarena", "visualwebarena", "workarena", "assistantbench", "weblinx"
]


@dataclass
class Benchmark(DataClassJsonMixin):
    name: str
    high_level_action_set_args: HighLevelActionSetArgs
    is_multi_tab: bool
    supports_parallel_seeds: bool
    env_args_list: list[EnvArgs]
    backends: list[BenchmarkBackend]
    task_metadata: Optional[pd.DataFrame] = field(
        default_factory=lambda: None,
        metadata=config(
            encoder=lambda df: df.to_dict(orient="records") if df is not None else None,
            decoder=lambda items: pd.DataFrame(items) if items is not None else None,
        ),
    )

    def __post_init__(self):
        # if no metadata is present, generate a dataframe with single "task_name" column
        if self.task_metadata is None:
            unique_task_names = list(set([env_args.task_name for env_args in self.env_args_list]))
            self.task_metadata = pd.DataFrame(
                [{"task_name": task_name} for task_name in unique_task_names]
            )
        # make sure all tasks in env_args are in the metadata
        metadata_tasks = list(self.task_metadata["task_name"])
        assert all([env_args.task_name in metadata_tasks for env_args in self.env_args_list])
        # check backend values
        for backend in self.backends:
            if backend not in typing.get_args(BenchmarkBackend):
                raise ValueError(
                    f"Unknown Benchmark backend {repr(backend)}. Available backends: {typing.get_args(BenchmarkBackend)}"
                )

    def prepare_backends(self):
        for backend in self.backends:
            logger.info(f"Preparing {backend} backend...")
            prepare_backend(backend)
            logger.info(f"{backend} backend ready")

    def subset_from_split(self, split: Literal["train", "valid", "test"]):
        split_column = "browsergym_split"

        # check for a split column in metadata
        if split_column not in self.task_metadata.columns:
            raise NotImplementedError(
                f"This benchmark does not provide default train/valid/test splits (missing a {repr(split_column)} column in task metadata)"
            )

        # recover the target split
        sub_benchmark = self.subset_from_regexp(split_column, regexp=f"^{split}$")
        sub_benchmark.name = f"{self.name}_{split}"

        # check that the split exists (non-empty task list)
        if not sub_benchmark.env_args_list:
            raise ValueError(f"The default {split} split for this benchmark is empty.")

        return sub_benchmark

    def subset_from_list(
        self,
        task_list: list[str],
        benchmark_name_suffix: Optional[str] = "custom",
        split: Optional[str] = None,
    ):
        """Create a sub-benchmark containing only the specified tasks.

        Args:
            task_list: List of task names to include in the sub-benchmark.
            benchmark_name_suffix: Optional suffix to append to the benchmark name. Defaults to "custom".
            split: Optional split name to append to the benchmark name. Useful for organization.

        Returns:
            Benchmark: A new benchmark instance containing only the specified tasks.

        Raises:
            ValueError: If the resulting task list is empty or if any specified task doesn't exist.
        """
        if not task_list:
            raise ValueError("Task list cannot be empty")

        # Convert task_list to set for more efficient lookups
        task_set = set(task_list)

        # Validate that all requested tasks exist in the original benchmark
        existing_tasks = {env_args.task_name for env_args in self.env_args_list}
        invalid_tasks = task_set - existing_tasks
        if invalid_tasks:
            raise ValueError(f"The following tasks do not exist in the benchmark: {invalid_tasks}")

        name = f"{self.name}_{benchmark_name_suffix}"
        if split:
            name += f"_{split}"

        sub_benchmark = Benchmark(
            name=name,
            high_level_action_set_args=self.high_level_action_set_args,
            is_multi_tab=self.is_multi_tab,
            supports_parallel_seeds=self.supports_parallel_seeds,
            backends=self.backends,
            env_args_list=[
                env_args for env_args in self.env_args_list if env_args.task_name in task_set
            ],
            task_metadata=self.task_metadata,
        )

        # This check is redundant now due to the validation above, but kept for safety
        if not sub_benchmark.env_args_list:
            raise ValueError(
                f"The custom {split if split else ''} split for this benchmark is empty."
            )

        return sub_benchmark

    def subset_from_glob(self, column, glob):
        subset = self.subset_from_regexp(column, regexp=fnmatch.translate(glob))
        subset.name = f"{self.name}[{column}={glob}]"
        return subset

    def subset_from_regexp(self, column, regexp):
        # extract the filtered task_name subset
        task_name_subset = task_list_from_metadata(self.task_metadata, {column: regexp})

        # return the sub benchmark
        return Benchmark(
            name=f"{self.name}[{column}=/{regexp}/]",
            high_level_action_set_args=self.high_level_action_set_args,
            is_multi_tab=self.is_multi_tab,
            supports_parallel_seeds=self.supports_parallel_seeds,
            backends=self.backends,
            env_args_list=[
                env_args
                for env_args in self.env_args_list
                if env_args.task_name in task_name_subset
            ],
            task_metadata=self.task_metadata,
        )

    def subset_from_task_ratio(self, ratio, seed):
        """Get a random subset of the tasks given a ratio and seed."""
        rng = random.Random(seed)
        task_names = list(set([env_args.task_name for env_args in self.env_args_list]))
        rng.shuffle(task_names)
        num_tasks = int(len(task_names) * ratio)
        task_name_subset = task_names[:num_tasks]

        return Benchmark(
            name=f"{self.name}[ratio={ratio}, seed={seed}]",
            high_level_action_set_args=self.high_level_action_set_args,
            is_multi_tab=self.is_multi_tab,
            supports_parallel_seeds=self.supports_parallel_seeds,
            backends=self.backends,
            env_args_list=[
                env_args
                for env_args in self.env_args_list
                if env_args.task_name in task_name_subset
            ],
            task_metadata=self.task_metadata,
        )

    def dependency_graph_over_tasks(self) -> dict[str, list[str]]:
        # recover all unique task_names present in the benchmark
        task_names = list(set([env_args.task_name for env_args in self.env_args_list]))

        # if the "depends_on" column is missing, log a warning and fall back to no dependencies
        # (we don't want the "depends_on" column to be mandatory)
        if "depends_on" not in self.task_metadata.columns:
            logger.warning(
                'This benchmark does not provide a dependency graph (missing a "depends_on" column in task metadata). Assuming no task dependencies.'
            )
            zero_dependencies = {task_name: [] for task_name in task_names}
            return zero_dependencies

        # recover the task dependency graph, for tasks in the benchmark only
        task_dependencies = extract_sparse_task_dependency_graph_from_subset(
            task_subset=task_names,
            parents=build_full_task_dependency_graph_from_metadata(
                task_metadata=self.task_metadata
            ),
        )

        return task_dependencies

    def dependency_graphs_over_env_args(self) -> list[dict[str, list[str]]]:
        """
        Returns a list of dependency graphs to be executed sequentially, typically with a full instance reset in-between.
        Ideally, a job scheduler should connect these graphs by injecting a reset task in-between each, which depends on all previous tasks being completed.
        """
        task_dependencies = self.dependency_graph_over_tasks()
        env_args_dependencies = build_env_args_dependency_graphs(
            env_args_list=self.env_args_list,
            task_dependencies=task_dependencies,
            supports_parallel_seeds=self.supports_parallel_seeds,
        )
        return env_args_dependencies
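
Below is a minimal usage sketch of the Benchmark API defined above; it is an illustration, not code from this commit. It assumes EnvArgs accepts task_name and task_seed keywords, the miniwob task names are illustrative, and "bid" is one of browsergym's HighLevelActionSet action subsets:

from agentlab.experiments.loop import EnvArgs

# Benchmark and HighLevelActionSetArgs are the classes defined in the file above.
benchmark = Benchmark(
    name="miniwob_tiny",
    high_level_action_set_args=HighLevelActionSetArgs(subsets=("bid",)),
    is_multi_tab=False,
    supports_parallel_seeds=True,
    env_args_list=[
        EnvArgs(task_name="miniwob.click-button", task_seed=0),  # task names are illustrative
        EnvArgs(task_name="miniwob.click-checkboxes", task_seed=0),
        EnvArgs(task_name="miniwob.enter-text", task_seed=0),
    ],
    backends=["miniwob"],
    # task_metadata is omitted: __post_init__ builds a one-column dataframe from the task names.
)

# Subsetting: by glob over a metadata column, or by a random ratio of tasks.
click_tasks = benchmark.subset_from_glob("task_name", "miniwob.click-*")
half = benchmark.subset_from_task_ratio(0.5, seed=42)

# Dependency graphs: with no "depends_on" column in the metadata, this logs a
# warning and returns an empty dependency list per task.
deps = benchmark.dependency_graph_over_tasks()

# Serialization via DataClassJsonMixin; the task_metadata field round-trips
# through the records-list encoder/decoder declared on the field.
payload = benchmark.to_json()
restored = Benchmark.from_json(payload)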
