Skip to content

Commit 945a50e

Browse files
committed
improved
1 parent 75c3092 commit 945a50e

File tree

1 file changed

+55
-38
lines changed
  • browsergym/experiments/src/browsergym/experiments/benchmark

1 file changed

+55
-38
lines changed

browsergym/experiments/src/browsergym/experiments/benchmark/base.py

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -96,57 +96,74 @@ def prepare_backends(self):
9696
prepare_backend(backend)
9797
logger.info(f"{backend} backend ready")
9898

99-
def subset_from_split(
99+
def subset_from_split(self, split: Literal["train", "valid", "test"]):
100+
split_column = "browsergym_split"
101+
102+
# check for a split column in metadata
103+
if split_column not in self.task_metadata.columns:
104+
raise NotImplementedError(
105+
f"This benchmark does not provide default train/valid/test splits (missing a {repr(split_column)} column in task metadata)"
106+
)
107+
108+
# recover the target split
109+
sub_benchmark = self.subset_from_regexp(split_column, regexp=f"^{split}$")
110+
sub_benchmark.name = f"{self.name}_{split}"
111+
112+
# check that the split exists (non-empty task list)
113+
if not sub_benchmark.env_args_list:
114+
raise ValueError(f"The default {split} split for this benchmark is empty.")
115+
116+
return sub_benchmark
117+
118+
def subset_from_list(
100119
self,
101-
split: Literal["train", "valid", "test"],
102-
task_splits: Optional[dict[str, list[str]]] = None,
120+
task_list: list[str],
103121
benchmark_name_suffix: Optional[str] = "custom",
122+
split: Optional[str] = None,
104123
):
105-
"""Create a subset of the benchmark containing only tasks from the specified split.
124+
"""Create a sub-benchmark containing only the specified tasks.
106125
107126
Args:
108-
split: The split to filter for ("train", "valid", or "test")
109-
task_splits: Optional dictionary mapping splits to lists of task names.
110-
Example: {"train": ["task1", "task2"], "valid": ["task3", "task4"], "test": ["task5", "task6"]}
111-
benchmark_name_suffix: Optional suffix to append to the new benchmark name
127+
task_list: List of task names to include in the sub-benchmark.
128+
benchmark_name_suffix: Optional suffix to append to the benchmark name. Defaults to "custom".
129+
split: Optional split name to append to the benchmark name. Useful for organization.
112130
113131
Returns:
114-
A new Benchmark instance containing only tasks from the specified split.
132+
Benchmark: A new benchmark instance containing only the specified tasks.
115133
116134
Raises:
117-
NotImplementedError: If task_splits is None and the metadata has no 'browsergym_split' column
118-
ValueError: If the resulting split would be empty
135+
ValueError: If the resulting task list is empty or if any specified task doesn't exist.
119136
"""
120-
if task_splits is not None:
121-
122-
sub_benchmark = Benchmark(
123-
name=f"{self.name}_{benchmark_name_suffix}_{split}",
124-
high_level_action_set_args=self.high_level_action_set_args,
125-
is_multi_tab=self.is_multi_tab,
126-
supports_parallel_seeds=self.supports_parallel_seeds,
127-
backends=self.backends,
128-
env_args_list=[
129-
env_args
130-
for env_args in self.env_args_list
131-
if env_args.task_name in task_splits[split]
132-
],
133-
task_metadata=self.task_metadata,
134-
)
135-
else:
136-
split_column = "browsergym_split"
137-
# check for a split column in metadata
138-
if split_column not in self.task_metadata.columns:
139-
raise NotImplementedError(
140-
f"This benchmark does not provide default train/valid/test splits (missing a {repr(split_column)} column in task metadata)"
141-
)
137+
if not task_list:
138+
raise ValueError("Task list cannot be empty")
142139

143-
# recover the target split
144-
sub_benchmark = self.subset_from_regexp(split_column, regexp=f"^{split}$")
145-
sub_benchmark.name = f"{self.name}_{split}"
140+
# Validate that all requested tasks exist in the original benchmark
141+
existing_tasks = {env_args.task_name for env_args in self.env_args_list}
142+
invalid_tasks = set(task_list) - existing_tasks
143+
if invalid_tasks:
144+
raise ValueError(f"The following tasks do not exist in the benchmark: {invalid_tasks}")
146145

147-
# check that the split exists (non-empty task list)
146+
name = f"{self.name}_{benchmark_name_suffix}"
147+
if split:
148+
name += f"_{split}"
149+
150+
sub_benchmark = Benchmark(
151+
name=name,
152+
high_level_action_set_args=self.high_level_action_set_args,
153+
is_multi_tab=self.is_multi_tab,
154+
supports_parallel_seeds=self.supports_parallel_seeds,
155+
backends=self.backends,
156+
env_args_list=[
157+
env_args for env_args in self.env_args_list if env_args.task_name in task_list
158+
],
159+
task_metadata=self.task_metadata,
160+
)
161+
162+
# This check is redundant now due to the validation above, but kept for safety
148163
if not sub_benchmark.env_args_list:
149-
raise ValueError(f"The {split} split for this benchmark is empty.")
164+
raise ValueError(
165+
f"The custom {split if split else ''} split for this benchmark is empty."
166+
)
150167

151168
return sub_benchmark
152169

0 commit comments

Comments
 (0)