@@ -96,57 +96,74 @@ def prepare_backends(self):
9696 prepare_backend (backend )
9797 logger .info (f"{ backend } backend ready" )
9898
def subset_from_split(self, split: Literal["train", "valid", "test"]):
    """Create a sub-benchmark containing only tasks from a default split.

    Args:
        split: The default split to extract ("train", "valid" or "test"),
            as recorded in the "browsergym_split" task metadata column.

    Returns:
        Benchmark: A new benchmark instance containing only the tasks of the
            requested split, named "<original_name>_<split>".

    Raises:
        NotImplementedError: If the task metadata has no "browsergym_split"
            column, i.e. the benchmark provides no default splits.
        ValueError: If the requested split exists but contains no tasks.
    """
    split_column = "browsergym_split"

    # check for a split column in metadata
    if split_column not in self.task_metadata.columns:
        raise NotImplementedError(
            f"This benchmark does not provide default train/valid/test splits (missing a {repr(split_column)} column in task metadata)"
        )

    # recover the target split
    sub_benchmark = self.subset_from_regexp(split_column, regexp=f"^{split}$")
    sub_benchmark.name = f"{self.name}_{split}"

    # check that the split exists (non-empty task list)
    if not sub_benchmark.env_args_list:
        raise ValueError(f"The default {split} split for this benchmark is empty.")

    return sub_benchmark
def subset_from_list(
    self,
    task_list: list[str],
    benchmark_name_suffix: Optional[str] = "custom",
    split: Optional[str] = None,
):
    """Create a sub-benchmark containing only the specified tasks.

    Args:
        task_list: List of task names to include in the sub-benchmark.
        benchmark_name_suffix: Optional suffix to append to the benchmark name. Defaults to "custom".
        split: Optional split name to append to the benchmark name. Useful for organization.

    Returns:
        Benchmark: A new benchmark instance containing only the specified tasks.

    Raises:
        ValueError: If the resulting task list is empty or if any specified task doesn't exist.
    """
    if not task_list:
        raise ValueError("Task list cannot be empty")

    # Validate that all requested tasks exist in the original benchmark
    requested_tasks = set(task_list)  # set gives O(1) membership tests below
    existing_tasks = {env_args.task_name for env_args in self.env_args_list}
    invalid_tasks = requested_tasks - existing_tasks
    if invalid_tasks:
        raise ValueError(f"The following tasks do not exist in the benchmark: {invalid_tasks}")

    name = f"{self.name}_{benchmark_name_suffix}"
    if split:
        name += f"_{split}"

    sub_benchmark = Benchmark(
        name=name,
        high_level_action_set_args=self.high_level_action_set_args,
        is_multi_tab=self.is_multi_tab,
        supports_parallel_seeds=self.supports_parallel_seeds,
        backends=self.backends,
        env_args_list=[
            # filter against the set, not the raw list, to avoid O(n*m) scans
            env_args
            for env_args in self.env_args_list
            if env_args.task_name in requested_tasks
        ],
        task_metadata=self.task_metadata,
    )

    # This check is redundant now due to the validation above, but kept for safety
    if not sub_benchmark.env_args_list:
        raise ValueError(
            f"The custom {split if split else ''} split for this benchmark is empty."
        )

    return sub_benchmark
152169
0 commit comments