|
| 1 | +from abc import ABC, abstractmethod |
1 | 2 | import gzip |
2 | 3 | import logging |
3 | 4 | import pickle |
4 | | -import re |
5 | 5 | import uuid |
6 | 6 | from abc import ABC, abstractmethod |
7 | 7 | from dataclasses import dataclass |
|
14 | 14 |
|
15 | 15 | from agentlab.agents.agent_args import AgentArgs |
16 | 16 | from agentlab.analyze import inspect_results |
17 | | -from agentlab.experiments import args |
18 | 17 | from agentlab.experiments import reproducibility_util as repro |
19 | 18 | from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies |
20 | 19 | from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments |
21 | 20 |
|
| 21 | + |
22 | 22 | logger = logging.getLogger(__name__) |
23 | 23 |
|
24 | 24 |
|
| 25 | +def make_study( |
| 26 | + agent_args: list[AgentArgs], |
| 27 | + benchmark: bgym.Benchmark, |
| 28 | + logging_level_stdout=logging.WARNING, |
| 29 | + suffix="", |
| 30 | + comment=None, |
| 31 | + ignore_dependencies=False, |
| 32 | +): |
| 33 | +    """Make a study from a list of agents and a benchmark (or the name of a default benchmark).""" |
| 34 | + |
| 35 | +    if isinstance(benchmark, str): |
| 36 | +        benchmark = bgym.DEFAULT_BENCHMARKS[benchmark]() |
| 37 | + |
| 38 | + if "webarena" in benchmark.name and len(agent_args) > 1: |
| 39 | + logger.warning( |
| 40 | +            "*WebArena* requires a manual reset after each evaluation. Running agents through SequentialStudies." |
| 41 | + ) |
| 42 | + studies = [] |
| 43 | + for agent in agent_args: |
| 44 | + studies.append( |
| 45 | + Study( |
| 46 | + [agent], |
| 47 | + benchmark, |
| 48 | + logging_level=logging_level_stdout, |
| 49 | + suffix=suffix, |
| 50 | + comment=comment, |
| 51 | + ignore_dependencies=ignore_dependencies, |
| 52 | + ) |
| 53 | + ) |
| 54 | + |
| 55 | + return SequentialStudies(studies) |
| 56 | + else: |
| 57 | + return Study( |
| 58 | + agent_args, |
| 59 | + benchmark, |
| 60 | + logging_level=logging_level_stdout, |
| 61 | + suffix=suffix, |
| 62 | + comment=comment, |
| 63 | + ignore_dependencies=ignore_dependencies, |
| 64 | + ) |
| 65 | + |
| 66 | + |
| 67 | +class AbstractStudy(ABC): |
| 68 | + dir: Path = None |
| 69 | + suffix: str = "" |
| 70 | + |
| 71 | + @abstractmethod |
| 72 | + def find_incomplete(self, include_errors=True): |
| 73 | +        """Search for incomplete or missing experiments in the study directory.""" |
| 74 | + |
| 75 | + @abstractmethod |
| 76 | + def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3): |
| 77 | + """Run the study""" |
| 78 | + |
| 79 | + def make_dir(self, exp_root=RESULTS_DIR): |
| 80 | + if self.dir is None: |
| 81 | + dir_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{self.name}" |
| 82 | + |
| 83 | + self.dir = Path(exp_root) / dir_name |
| 84 | + self.dir.mkdir(parents=True, exist_ok=True) |
| 85 | + |
| 86 | + def save(self, exp_root=RESULTS_DIR): |
| 87 | + """Pickle the study to the directory""" |
| 88 | + # TODO perhaps remove exp_args_list before pickling and when loading bring them from the individual directories |
| 89 | + |
| 90 | + self.make_dir(exp_root=exp_root) |
| 91 | + with gzip.open(self.dir / "study.pkl.gz", "wb") as f: |
| 92 | + pickle.dump(self, f) |
| 93 | + |
| 94 | + def get_results(self, suffix="", also_save=True): |
| 95 | + """Recursively load all results from the study directory and summarize them.""" |
| 96 | + result_df = inspect_results.load_result_df(self.dir) |
| 97 | + error_report = inspect_results.error_report(result_df, max_stack_trace=3, use_log=True) |
| 98 | + summary_df = inspect_results.summarize_study(result_df) |
| 99 | + |
| 100 | + if also_save: |
| 101 | + suffix = f"_{suffix}" if suffix else "" |
| 102 | + result_df.to_csv(self.dir / f"result_df{suffix}.csv") |
| 103 | + summary_df.to_csv(self.dir / f"summary_df{suffix}.csv") |
| 104 | + (self.dir / f"error_report{suffix}.md").write_text(error_report) |
| 105 | + |
| 106 | + return result_df, summary_df, error_report |
| 107 | + |
| 108 | + |
25 | 109 | @dataclass |
26 | | -class Study: |
| 110 | +class Study(AbstractStudy): |
27 | 111 |     """A study corresponds to one or multiple agents evaluated on a benchmark. |
28 | 112 |
|
29 | 113 | This is part of the high level API to help keep experiments organized and reproducible. |
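
Before the Study class below, a minimal usage sketch of the new make_study helper may help. The import paths and the AGENT_4o_MINI config are assumptions for illustration (they are not part of this diff), and "miniwob" is assumed to be a key of bgym.DEFAULT_BENCHMARKS:

    from agentlab.agents.generic_agent import AGENT_4o_MINI  # assumed example AgentArgs config
    from agentlab.experiments.study import make_study        # assumed module path for this file

    # Passing the benchmark as a string resolves it through bgym.DEFAULT_BENCHMARKS.
    study = make_study(
        agent_args=[AGENT_4o_MINI],
        benchmark="miniwob",
        suffix="demo",
        comment="quick smoke test",
    )
    # With several agents on a WebArena benchmark, make_study returns SequentialStudies instead.
    study.run(n_jobs=4)
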
@@ -139,7 +223,7 @@ def run( |
139 | 223 | self._run(n_jobs, parallel_backend, strict_reproducibility) |
140 | 224 |
|
141 | 225 | suffix = f"trial_{i + 1}_of_{n_relaunch}" |
142 | | - _, summary_df, error_report = self.get_results(suffix=suffix) |
| 226 | + _, summary_df, _ = self.get_results(suffix=suffix) |
143 | 227 | logger.info("\n" + str(summary_df)) |
144 | 228 |
|
145 | 229 | n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors) |
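
The relaunch loop above snapshots results after every trial. A rough sketch of what that leaves in study.dir and how it can be inspected afterwards, assuming study is a Study whose run() has finished:

    # Each trial saves suffixed snapshots via get_results(suffix="trial_i_of_n"), e.g.
    #   result_df_trial_1_of_3.csv, summary_df_trial_1_of_3.csv, error_report_trial_1_of_3.md
    result_df, summary_df, error_report = study.get_results()  # unsuffixed, final aggregate
    print(summary_df)

    # find_incomplete() yields the counts used above to decide whether another relaunch is needed.
    n_incomplete, n_error = study.find_incomplete(include_errors=True)
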
@@ -197,60 +281,17 @@ def append_to_journal(self, strict_reproducibility=True): |
197 | 281 | ValueError: If the reproducibility information is not compatible |
198 | 282 | with the report. |
199 | 283 | """ |
| 284 | + _, summary_df, _ = self.get_results() |
200 | 285 | repro.append_to_journal( |
201 | 286 | self.reproducibility_info, |
202 | | - self.get_report(), |
| 287 | + summary_df, |
203 | 288 | strict_reproducibility=strict_reproducibility, |
204 | 289 | ) |
205 | 290 |
|
206 | | - def get_results(self, suffix="", also_save=True): |
207 | | - result_df = inspect_results.load_result_df(self.dir) |
208 | | - error_report = inspect_results.error_report(result_df, max_stack_trace=3, use_log=True) |
209 | | - summary_df = inspect_results.summarize_study(result_df) |
210 | | - |
211 | | - if also_save: |
212 | | - suffix = f"_{suffix}" if suffix else "" |
213 | | - result_df.to_csv(self.dir / f"result_df{suffix}.csv") |
214 | | - summary_df.to_csv(self.dir / f"summary_df{suffix}.csv") |
215 | | - (self.dir / f"error_report{suffix}.md").write_text(error_report) |
216 | | - |
217 | | - return result_df, summary_df, error_report |
218 | | - |
219 | 291 | @property |
220 | 292 | def name(self): |
221 | 293 | agent_names = [a.agent_name for a in self.agent_args] |
222 | | - if len(agent_names) == 1: |
223 | | - study_name = f"{agent_names[0]}_on_{self.benchmark.name}" |
224 | | - else: |
225 | | - study_name = f"{len(agent_names)}_agents_on_{self.benchmark.name}" |
226 | | - |
227 | | - study_name = slugify(study_name, max_length=100, allow_unicode=True) |
228 | | - |
229 | | - if self.suffix: |
230 | | - study_name += f"_{self.suffix}" |
231 | | - return study_name |
232 | | - |
233 | | - def make_dir(self, exp_root=RESULTS_DIR): |
234 | | - if self.dir is None: |
235 | | - dir_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{self.name}" |
236 | | - |
237 | | - self.dir = Path(exp_root) / dir_name |
238 | | - self.dir.mkdir(parents=True, exist_ok=True) |
239 | | - |
240 | | - def save(self): |
241 | | - """Pickle the study to the directory""" |
242 | | - |
243 | | - # TODO perhaps remove exp_args_list before pickling and when loading bring them from the individual directories |
244 | | - |
245 | | - self.make_dir() |
246 | | - |
247 | | - with gzip.open(self.dir / "study.pkl.gz", "wb") as f: |
248 | | - pickle.dump(self, f) |
249 | | - |
250 | | - def get_report(self, ignore_cache=False, ignore_stale=False): |
251 | | - return inspect_results.get_study_summary( |
252 | | - self.dir, ignore_cache=ignore_cache, ignore_stale=ignore_stale |
253 | | - ) |
| 294 | + return _make_study_name(agent_names, [self.benchmark.name], self.suffix) |
254 | 295 |
|
255 | 296 | def override_max_steps(self, max_steps): |
256 | 297 | for exp_args in self.exp_args_list: |
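
The name property above now delegates to _make_study_name, which is added in the hunk below. The agent and benchmark names here are placeholders, and the exact output depends on slugify's settings, so it is shown only approximately:

    _make_study_name(["GenericAgent-gpt-4o"], ["miniwob"], suffix="run1")
    # -> a slug derived from "GenericAgent-gpt-4o_on_miniwob_run1" (max_length=200, allow_unicode=True)

    _make_study_name(["agent_a", "agent_b"], ["webarena", "workarena"])
    # -> a slug derived from "2_agents_on_2_benchmarks"; multiple names collapse to counts
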
@@ -285,6 +326,64 @@ def load_most_recent(root_dir: Path = None, contains=None) -> "Study": |
285 | 326 | return Study.load(get_most_recent_study(root_dir, contains=contains)) |
286 | 327 |
|
287 | 328 |
|
| 329 | +def _make_study_name(agent_names, benchmark_names, suffix=None): |
| 330 | + """Make a study name from the agent and benchmark names.""" |
| 331 | + if len(agent_names) == 1: |
| 332 | + agent_name = agent_names[0] |
| 333 | + else: |
| 334 | + agent_name = f"{len(agent_names)}_agents" |
| 335 | + |
| 336 | + if len(benchmark_names) == 1: |
| 337 | + benchmark_name = benchmark_names[0] |
| 338 | + else: |
| 339 | + benchmark_name = f"{len(benchmark_names)}_benchmarks" |
| 340 | + |
| 341 | +    study_name = f"{agent_name}_on_{benchmark_name}" + (f"_{suffix}" if suffix else "") |
| 342 | + |
| 343 | + return slugify(study_name, max_length=200, allow_unicode=True) |
| 344 | + |
| 345 | + |
| 346 | +@dataclass |
| 347 | +class SequentialStudies(AbstractStudy): |
| 348 | + """ |
| 349 | + Sequential execution of multiple studies. |
| 350 | +
|
| 351 | + This is required for e.g. WebArena, where a server reset is required between evaluations of each agent. |
| 352 | + """ |
| 353 | + |
| 354 | + studies: list[Study] |
| 355 | + |
| 356 | + @property |
| 357 | + def name(self): |
| 358 | + """The name of the study.""" |
| 359 | + agent_names = [a.agent_name for study in self.studies for a in study.agent_args] |
| 360 | + benchmark_names = [study.benchmark.name for study in self.studies] |
| 361 | + return _make_study_name(agent_names, benchmark_names, self.suffix) |
| 362 | + |
| 363 | + def find_incomplete(self, include_errors=True): |
| 364 | + for study in self.studies: |
| 365 | + study.find_incomplete(include_errors=include_errors) |
| 366 | + |
| 367 | + def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3): |
| 368 | + |
| 369 | + self.save() |
| 370 | + |
| 371 | + for study in self.studies: |
| 372 | + study.make_dir(exp_root=self.dir) |
| 373 | + study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch) |
| 374 | + _, summary_df, _ = self.get_results() |
| 375 | + logger.info("\n" + str(summary_df)) |
| 376 | + logger.info(f"SequentialStudies {self.name} finished.") |
| 377 | + |
| 378 | + def override_max_steps(self, max_steps): |
| 379 | + for study in self.studies: |
| 380 | + study.override_max_steps(max_steps) |
| 381 | + |
| 382 | + def append_to_journal(self, strict_reproducibility=True): |
| 383 | + for study in self.studies: |
| 384 | + study.append_to_journal(strict_reproducibility=strict_reproducibility) |
| 385 | + |
| 386 | + |
288 | 387 | def get_most_recent_study( |
289 | 388 | root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None |
290 | 389 | ): |
|
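
Finally, a sketch of the sequential path end to end. The agent configs are placeholders, and WebArena still needs the manual server reset between agents that the warning in make_study mentions:

    # With more than one agent on a WebArena benchmark, make_study returns SequentialStudies,
    # which gives each agent's Study its own subdirectory under the parent study dir.
    study = make_study(
        agent_args=[AGENT_A, AGENT_B],  # placeholder AgentArgs instances
        benchmark="webarena",
        comment="two agents, run sequentially",
    )
    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=3)
    study.append_to_journal(strict_reproducibility=False)

    # Studies are pickled to study.pkl.gz by save(); the most recent one can be reloaded later:
    latest = Study.load_most_recent(contains="webarena")
    latest.get_results()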