Skip to content

Commit c676eab

Browse files
committed
typo
1 parent 417893c commit c676eab

File tree

1 file changed

+41
-11
lines changed

1 file changed

+41
-11
lines changed

src/agentlab/experiments/study.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import uuid
77
from abc import ABC, abstractmethod
88
from concurrent.futures import ProcessPoolExecutor
9-
from dataclasses import dataclass
9+
from dataclasses import asdict, dataclass
1010
from datetime import datetime
1111
from multiprocessing import Manager, Pool, Queue
1212
from pathlib import Path
@@ -19,7 +19,11 @@
1919
from agentlab.analyze import inspect_results
2020
from agentlab.experiments import reproducibility_util as repro
2121
from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
22-
from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
22+
from agentlab.experiments.launch_exp import (
23+
find_incomplete,
24+
non_dummy_count,
25+
run_experiments,
26+
)
2327
from agentlab.experiments.loop import EnvArgs, ExpArgs
2428
from agentlab.experiments.multi_server import BaseServer
2529

@@ -238,6 +242,9 @@ def __post_init__(self):
238242
self.uuid = uuid.uuid4()
239243
if isinstance(self.benchmark, str):
240244
self.benchmark = DEFAULT_BENCHMARKS[self.benchmark.lower()]()
245+
246+
self.benchmark.env_args_list = _convert_env_args(self.benchmark.env_args_list)
247+
241248
if isinstance(self.dir, str):
242249
self.dir = Path(self.dir)
243250
self.make_exp_args_list()
@@ -324,28 +331,31 @@ def run(
324331
self._run(n_jobs, parallel_backend, strict_reproducibility)
325332

326333
suffix = f"trial_{i + 1}_of_{n_relaunch}"
327-
_, summary_df, _ = self.get_results(suffix=suffix)
334+
_, summary_df, error_report = self.get_results(suffix=suffix)
328335
logger.info("\n" + str(summary_df))
329336

330337
n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors)
331338

332339
if n_error / n_exp > 0.3:
333-
logger.warning("More than 30% of the experiments errored. Stopping the study.")
334-
return
340+
logger.warning("More than 30% of the experiments errored. Stopping the retries.")
341+
break
335342

336343
if last_error_count is not None and n_error >= last_error_count:
337344
logger.warning(
338-
"Last trial did not reduce the number of errors. Stopping the study."
345+
"Last trial did not reduce the number of errors. Stopping the retries."
339346
)
340-
return
347+
break
341348

342349
if n_incomplete == 0:
343350
logger.info(f"Study {self.name} finished.")
344-
return
351+
break
345352

346-
logger.warning(
347-
f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments."
348-
)
353+
logger.info("# Error Report:\n-------------\n\n" + error_report)
354+
355+
if n_incomplete != 0:
356+
logger.warning(
357+
f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments."
358+
)
349359

350360
def _run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False):
351361
"""Run all experiments in the study in parallel when possible.
@@ -715,6 +725,26 @@ def set_demo_mode(env_args_list: list[EnvArgs]):
715725
env_args.slow_mo = 1000
716726

717727

728+
def _convert_env_args(env_args_list):
729+
"""Return a list where every element is the *new* EnvArgs.
730+
731+
For backward compatibility, we need to convert the old EnvArgs to the new one.
732+
"""
733+
from bgym import EnvArgs as BGymEnvArgs
734+
735+
new_list = []
736+
for ea in env_args_list:
737+
# already new → keep as‑is
738+
if isinstance(ea, EnvArgs):
739+
new_list.append(ea)
740+
# old → convert
741+
elif isinstance(ea, BGymEnvArgs):
742+
new_list.append(EnvArgs(**asdict(ea)))
743+
else:
744+
raise TypeError(f"Unexpected type: {type(ea)}")
745+
return new_list
746+
747+
718748
# def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark):
719749
# if benchmark.name.startswith("visualwebarena"):
720750
# sequential_subset = benchmark.subset_from_glob("requires_reset", "True")

0 commit comments

Comments
 (0)