|
6 | 6 | import uuid |
7 | 7 | from abc import ABC, abstractmethod |
8 | 8 | from concurrent.futures import ProcessPoolExecutor |
9 | | -from dataclasses import dataclass |
| 9 | +from dataclasses import asdict, dataclass |
10 | 10 | from datetime import datetime |
11 | 11 | from multiprocessing import Manager, Pool, Queue |
12 | 12 | from pathlib import Path |
|
19 | 19 | from agentlab.analyze import inspect_results |
20 | 20 | from agentlab.experiments import reproducibility_util as repro |
21 | 21 | from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies |
22 | | -from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments |
| 22 | +from agentlab.experiments.launch_exp import ( |
| 23 | + find_incomplete, |
| 24 | + non_dummy_count, |
| 25 | + run_experiments, |
| 26 | +) |
23 | 27 | from agentlab.experiments.loop import EnvArgs, ExpArgs |
24 | 28 | from agentlab.experiments.multi_server import BaseServer |
25 | 29 |
|
@@ -238,6 +242,9 @@ def __post_init__(self): |
238 | 242 | self.uuid = uuid.uuid4() |
239 | 243 | if isinstance(self.benchmark, str): |
240 | 244 | self.benchmark = DEFAULT_BENCHMARKS[self.benchmark.lower()]() |
| 245 | + |
| 246 | + self.benchmark.env_args_list = _convert_env_args(self.benchmark.env_args_list) |
| 247 | + |
241 | 248 | if isinstance(self.dir, str): |
242 | 249 | self.dir = Path(self.dir) |
243 | 250 | self.make_exp_args_list() |
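
The call added to `__post_init__` above normalizes `benchmark.env_args_list` once at construction time, so everything downstream can assume the new `EnvArgs` type. A minimal, self-contained sketch of that pattern follows — `OldArgs`, `NewArgs`, and `Holder` are hypothetical stand-ins, not the real `Study`/`EnvArgs` classes:

```python
from dataclasses import asdict, dataclass, field


@dataclass
class OldArgs:  # stand-in for the legacy bgym.EnvArgs
    task_name: str
    max_steps: int = 10


@dataclass
class NewArgs:  # stand-in for the new agentlab EnvArgs
    task_name: str
    max_steps: int = 10


@dataclass
class Holder:  # stand-in for Study, normalizing in __post_init__
    env_args_list: list = field(default_factory=list)

    def __post_init__(self):
        # Convert once at construction; downstream code can then rely on
        # NewArgs everywhere instead of repeating isinstance checks.
        self.env_args_list = [
            ea if isinstance(ea, NewArgs) else NewArgs(**asdict(ea))
            for ea in self.env_args_list
        ]


h = Holder(env_args_list=[OldArgs("demo"), NewArgs("demo2")])
assert all(isinstance(ea, NewArgs) for ea in h.env_args_list)
```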
@@ -324,28 +331,31 @@ def run( |
324 | 331 | self._run(n_jobs, parallel_backend, strict_reproducibility) |
325 | 332 |
|
326 | 333 | suffix = f"trial_{i + 1}_of_{n_relaunch}" |
327 | | - _, summary_df, _ = self.get_results(suffix=suffix) |
| 334 | + _, summary_df, error_report = self.get_results(suffix=suffix) |
328 | 335 | logger.info("\n" + str(summary_df)) |
329 | 336 |
|
330 | 337 | n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors) |
331 | 338 |
|
332 | 339 | if n_error / n_exp > 0.3: |
333 | | - logger.warning("More than 30% of the experiments errored. Stopping the study.") |
334 | | - return |
| 340 | + logger.warning("More than 30% of the experiments errored. Stopping the retries.") |
| 341 | + break |
335 | 342 |
|
336 | 343 | if last_error_count is not None and n_error >= last_error_count: |
337 | 344 | logger.warning( |
338 | | - "Last trial did not reduce the number of errors. Stopping the study." |
| 345 | + "Last trial did not reduce the number of errors. Stopping the retries." |
339 | 346 | ) |
340 | | - return |
| 347 | + break |
341 | 348 |
|
342 | 349 | if n_incomplete == 0: |
343 | 350 | logger.info(f"Study {self.name} finished.") |
344 | | - return |
| 351 | + break |
345 | 352 |
|
346 | | - logger.warning( |
347 | | - f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments." |
348 | | - ) |
| 353 | + logger.info("# Error Report:\n-------------\n\n" + error_report) |
| 354 | + |
| 355 | + if n_incomplete != 0: |
| 356 | + logger.warning( |
| 357 | + f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments." |
| 358 | + ) |
349 | 359 |
|
350 | 360 | def _run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False): |
351 | 361 | """Run all experiments in the study in parallel when possible. |
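
The switch from `return` to `break` in the retry loop above is the substantive change: early exits now fall through to the error-report logging and the final "did not finish" warning instead of skipping them. A minimal sketch of that control flow, where `run_trial` is a hypothetical per-trial callback returning `(n_incomplete, n_error)` and not part of the actual API:

```python
import logging

logger = logging.getLogger("study")


def run_with_retries(run_trial, n_exp: int, n_relaunch: int) -> None:
    last_error_count = None
    n_incomplete = n_exp
    for i in range(n_relaunch):
        n_incomplete, n_error = run_trial(i)  # hypothetical callback

        # `break` (not `return`) so the post-loop reporting always runs.
        if n_error / n_exp > 0.3:
            logger.warning("More than 30% of the experiments errored. Stopping the retries.")
            break
        if last_error_count is not None and n_error >= last_error_count:
            logger.warning("Last trial did not reduce the number of errors. Stopping the retries.")
            break
        if n_incomplete == 0:
            logger.info("Study finished.")
            break
        last_error_count = n_error

    # This runs on every exit path; with `return` it was skipped.
    if n_incomplete != 0:
        logger.warning(
            f"Study did not finish after {n_relaunch} trials. "
            f"There are {n_incomplete} incomplete experiments."
        )
```

For example, `run_with_retries(lambda i: (0, 0), n_exp=10, n_relaunch=3)` logs "Study finished." on the first trial and skips the final warning.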
@@ -715,6 +725,26 @@ def set_demo_mode(env_args_list: list[EnvArgs]): |
715 | 725 | env_args.slow_mo = 1000 |
716 | 726 |
|
717 | 727 |
|
| 728 | +def _convert_env_args(env_args_list): |
| 729 | + """Return a list where every element is the *new* EnvArgs. |
| 730 | +
|
| 731 | +    For backward compatibility, any old bgym.EnvArgs instances are converted to the new EnvArgs. |
| 732 | + """ |
| 733 | + from bgym import EnvArgs as BGymEnvArgs |
| 734 | + |
| 735 | + new_list = [] |
| 736 | + for ea in env_args_list: |
| 737 | +        # already new → keep as-is |
| 738 | + if isinstance(ea, EnvArgs): |
| 739 | + new_list.append(ea) |
| 740 | + # old → convert |
| 741 | + elif isinstance(ea, BGymEnvArgs): |
| 742 | + new_list.append(EnvArgs(**asdict(ea))) |
| 743 | + else: |
| 744 | + raise TypeError(f"Unexpected type: {type(ea)}") |
| 745 | + return new_list |
| 746 | + |
| 747 | + |
718 | 748 | # def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark): |
719 | 749 | # if benchmark.name.startswith("visualwebarena"): |
720 | 750 | # sequential_subset = benchmark.subset_from_glob("requires_reset", "True") |
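
One caveat of the `EnvArgs(**asdict(ea))` conversion in `_convert_env_args` worth keeping in mind: it is lossless only while the two dataclasses share field names, and any divergence surfaces as a `TypeError` rather than silent data loss. A self-contained sketch of both cases — all class names here are hypothetical stand-ins, not the real `bgym`/agentlab types:

```python
from dataclasses import asdict, dataclass


@dataclass
class OldEnvArgs:  # stand-in for bgym.EnvArgs
    task_name: str
    task_seed: int = 0


@dataclass
class NewEnvArgs:  # stand-in for agentlab's EnvArgs
    task_name: str
    task_seed: int = 0


# Happy path: identical field names make the asdict round-trip lossless.
old = OldEnvArgs("miniwob.click-test", task_seed=42)
new = NewEnvArgs(**asdict(old))
assert (new.task_name, new.task_seed) == ("miniwob.click-test", 42)


# Failure mode: an extra field on the old class raises TypeError, so a
# schema mismatch fails loudly instead of silently dropping data.
@dataclass
class DivergedEnvArgs:
    task_name: str
    extra_flag: bool = False  # field the new class does not know about


try:
    NewEnvArgs(**asdict(DivergedEnvArgs("demo")))
except TypeError as exc:
    print(f"conversion failed as expected: {exc}")
```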
|