|
1 | 1 | import numpy as np |
2 | | - |
3 | | -from browsergym.experiments.benchmark.metadata.utils import task_list_from_metadata, task_metadata |
| 2 | +from browsergym.experiments.benchmark.metadata.utils import ( |
| 3 | + task_list_from_metadata, |
| 4 | + task_metadata, |
| 5 | +) |
4 | 6 | from browsergym.experiments.benchmark.utils import ( |
5 | 7 | make_env_args_list_from_fixed_seeds, |
6 | 8 | make_env_args_list_from_repeat_tasks, |
|
88 | 90 |
|
89 | 91 | # all benchmarks are callables designed for lazy loading, i.e. `bench = DEFAULT_BENCHMARKS["miniwob_all"]()` |
90 | 92 | DEFAULT_BENCHMARKS = { |
91 | | - "miniwob": lambda: Benchmark( |
| 93 | + "miniwob": lambda n_repeats=5: Benchmark( |
92 | 94 | name="miniwob", |
93 | 95 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob_all"], |
94 | 96 | is_multi_tab=False, |
|
97 | 99 | env_args_list=make_env_args_list_from_repeat_tasks( |
98 | 100 | task_list=task_list_from_metadata(metadata=task_metadata("miniwob")), |
99 | 101 | max_steps=10, |
100 | | - n_repeats=5, |
| 102 | + n_repeats=n_repeats, |
101 | 103 | seeds_rng=np.random.RandomState(42), |
102 | 104 | ), |
103 | 105 | task_metadata=task_metadata("miniwob"), |
104 | 106 | ), |
105 | | - "miniwob_tiny_test": lambda: Benchmark( |
| 107 | + "miniwob_tiny_test": lambda n_repeats=2: Benchmark( |
106 | 108 | name="miniwob_tiny_test", |
107 | 109 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob_all"], |
108 | 110 | is_multi_tab=False, |
|
111 | 113 | env_args_list=make_env_args_list_from_repeat_tasks( |
112 | 114 | task_list=["miniwob.click-dialog", "miniwob.click-checkboxes"], |
113 | 115 | max_steps=5, |
114 | | - n_repeats=2, |
| 116 | + n_repeats=n_repeats, |
115 | 117 | seeds_rng=np.random.RandomState(42), |
116 | 118 | ), |
117 | 119 | task_metadata=task_metadata("miniwob"), |
118 | 120 | ), |
119 | | - "webarena": lambda: Benchmark( |
| 121 | + "webarena": lambda n_repeats=1: Benchmark( |
120 | 122 | name="webarena", |
121 | 123 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["webarena"], |
122 | 124 | is_multi_tab=True, |
|
125 | 127 | env_args_list=make_env_args_list_from_repeat_tasks( |
126 | 128 | task_list=task_list_from_metadata(metadata=task_metadata("webarena")), |
127 | 129 | max_steps=30, |
128 | | - n_repeats=1, |
| 130 | + n_repeats=n_repeats, |
129 | 131 | seeds_rng=np.random.RandomState(42), |
130 | 132 | ), |
131 | 133 | task_metadata=task_metadata("webarena"), |
132 | 134 | ), |
133 | | - "webarena_tiny": lambda: Benchmark( |
| 135 | + "webarena_tiny": lambda n_repeats=1: Benchmark( |
134 | 136 | name="webarena_tiny", |
135 | 137 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["webarena"], |
136 | 138 | is_multi_tab=True, |
|
150 | 152 | ), |
151 | 153 | task_metadata=task_metadata("webarena"), |
152 | 154 | ), |
153 | | - "visualwebarena_tiny": lambda: Benchmark( |
| 155 | + "visualwebarena_tiny": lambda n_repeats=10: Benchmark( |
154 | 156 | name="visualwebarena_tiny", |
155 | 157 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["visualwebarena"], |
156 | 158 | is_multi_tab=True, |
|
168 | 170 | ), |
169 | 171 | task_metadata=task_metadata("visualwebarena"), |
170 | 172 | ), |
171 | | - "visualwebarena": lambda: Benchmark( |
| 173 | + "visualwebarena": lambda n_repeats=1: Benchmark( |
172 | 174 | name="visualwebarena", |
173 | 175 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["visualwebarena"], |
174 | 176 | is_multi_tab=True, |
|
177 | 179 | env_args_list=make_env_args_list_from_repeat_tasks( |
178 | 180 | task_list=task_list_from_metadata(metadata=task_metadata("visualwebarena")), |
179 | 181 | max_steps=30, |
180 | | - n_repeats=1, |
| 182 | + n_repeats=n_repeats, |
181 | 183 | seeds_rng=np.random.RandomState(42), |
182 | 184 | ), |
183 | 185 | task_metadata=task_metadata("visualwebarena"), |
184 | 186 | ), |
185 | | - "workarena_l1": lambda: Benchmark( |
| 187 | + "workarena_l1": lambda n_repeats=10: Benchmark( |
186 | 188 | name="workarena_l1", |
187 | 189 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena"], |
188 | 190 | is_multi_tab=False, |
|
194 | 196 | meta_seed=42, # meta seed for evaluation curriculum |
195 | 197 | max_steps=15, |
196 | 198 | curriculum_type="agent", |
197 | | - seeds_l1=10, |
| 199 | + seeds_l1=n_repeats, |
198 | 200 | ), |
199 | 201 | task_metadata=task_metadata("workarena"), |
200 | 202 | ), |
201 | | - "workarena_l2_agent_curriculum_eval": lambda: Benchmark( |
| 203 | + "workarena_l2_agent_curriculum_eval": lambda n_repeats=1: Benchmark( |
202 | 204 | name="workarena_l2_agent_curriculum_eval", |
203 | 205 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena++"], |
204 | 206 | is_multi_tab=True, |
|
213 | 215 | ), |
214 | 216 | task_metadata=task_metadata("workarena"), |
215 | 217 | ), |
216 | | - "workarena_l3_agent_curriculum_eval": lambda: Benchmark( |
| 218 | + "workarena_l3_agent_curriculum_eval": lambda n_repeats=1: Benchmark( |
217 | 219 | name="workarena_l3_agent_curriculum_eval", |
218 | 220 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena++"], |
219 | 221 | is_multi_tab=True, |
|
228 | 230 | ), |
229 | 231 | task_metadata=task_metadata("workarena"), |
230 | 232 | ), |
231 | | - "assistantbench": lambda: Benchmark( |
| 233 | + "assistantbench": lambda n_repeats=1: Benchmark( |
232 | 234 | name="assistantbench", |
233 | 235 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["assistantbench"], |
234 | 236 | is_multi_tab=True, |
|
239 | 241 | metadata=task_metadata("assistantbench"), filter={"browsergym_split": "valid|test"} |
240 | 242 | ), |
241 | 243 | max_steps=30, |
242 | | - n_repeats=1, |
| 244 | + n_repeats=n_repeats, |
243 | 245 | seeds_rng=np.random.RandomState(42), |
244 | 246 | ), |
245 | 247 | task_metadata=task_metadata("assistantbench"), |
246 | 248 | ), |
247 | | - "weblinx": lambda: Benchmark( |
| 249 | + "weblinx": lambda n_repeats=1: Benchmark( |
248 | 250 | name="weblinx", |
249 | 251 | high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["weblinx"], |
250 | 252 | is_multi_tab=True, |
|
253 | 255 | env_args_list=make_env_args_list_from_repeat_tasks( |
254 | 256 | task_list=task_list_from_metadata(metadata=task_metadata("weblinx")), |
255 | 257 | max_steps=1, |
256 | | - n_repeats=1, |
| 258 | + n_repeats=n_repeats, |
257 | 259 | seeds_rng=np.random.RandomState(42), |
258 | 260 | ), |
259 | 261 | task_metadata=task_metadata("weblinx"), |
|
0 commit comments