Skip to content

Commit 50d4571

Browse files
committed
sequential studies
1 parent 6a50756 commit 50d4571

File tree

1 file changed

+149
-50
lines changed

1 file changed

+149
-50
lines changed

src/agentlab/experiments/study.py

Lines changed: 149 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from abc import ABC, abstractmethod
12
import gzip
23
import logging
34
import pickle
4-
import re
55
import uuid
66
from dataclasses import dataclass
77
from datetime import datetime
@@ -13,7 +13,6 @@
1313

1414
from agentlab.agents.agent_args import AgentArgs
1515
from agentlab.analyze import inspect_results
16-
from agentlab.experiments import args
1716
from agentlab.experiments import reproducibility_util as repro
1817
from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
1918
from agentlab.experiments.launch_exp import (
@@ -22,11 +21,96 @@
2221
run_experiments,
2322
)
2423

24+
2525
logger = logging.getLogger(__name__)
2626

2727

28+
def make_study(
    agent_args: list[AgentArgs],
    benchmark: bgym.Benchmark,
    logging_level_stdout=logging.WARNING,
    suffix="",
    comment=None,
    ignore_dependencies=False,
):
    """Make a study from a list of agents and a benchmark.

    Args:
        agent_args: agents to evaluate.
        benchmark: the benchmark to evaluate on, or its key in
            ``bgym.DEFAULT_BENCHMARKS``.
        logging_level_stdout: stdout logging level passed to each Study.
        suffix: optional tag appended to the study name.
        comment: free-form comment stored with the study.
        ignore_dependencies: if True, task dependencies are ignored.

    Returns:
        A ``Study``, or a ``SequentialStudies`` wrapping one single-agent
        Study per agent when the benchmark requires sequential runs.
    """
    # Accept a benchmark name and resolve it to a benchmark instance.
    if isinstance(benchmark, str):
        benchmark = bgym.DEFAULT_BENCHMARKS[benchmark]()

    # WebArena needs a manual server reset between agent evaluations, so each
    # agent gets its own Study and they are executed one after the other.
    if "webarena" in benchmark.name and len(agent_args) > 1:
        logger.warning(
            "*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
        )
        studies = [
            Study(
                [agent],
                benchmark,
                logging_level=logging_level_stdout,
                suffix=suffix,
                comment=comment,
                ignore_dependencies=ignore_dependencies,
            )
            for agent in agent_args
        ]
        return SequentialStudies(studies)
    else:
        return Study(
            agent_args,
            benchmark,
            logging_level=logging_level_stdout,
            suffix=suffix,
            comment=comment,
            ignore_dependencies=ignore_dependencies,
        )
69+
70+
class AbstractStudy(ABC):
    """Shared interface and persistence helpers for study variants."""

    # Directory holding the study's artifacts; created lazily by make_dir().
    dir: Path = None
    # Optional tag appended to the auto-generated study name.
    suffix: str = ""

    @abstractmethod
    def find_incomplete(self, include_errors=True):
        """Search for missing or incomplete experiments."""

    @abstractmethod
    def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
        """Run the study."""

    def make_dir(self, exp_root=RESULTS_DIR):
        """Ensure self.dir exists, naming it from a timestamp and the study name."""
        if self.dir is None:
            stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            self.dir = Path(exp_root) / f"{stamp}_{self.name}"
        self.dir.mkdir(parents=True, exist_ok=True)

    def save(self, exp_root=RESULTS_DIR):
        """Pickle the study into its directory as a gzipped file."""
        # TODO perhaps remove exp_args_list before pickling and when loading bring them from the individual directories
        self.make_dir(exp_root=exp_root)
        with gzip.open(self.dir / "study.pkl.gz", "wb") as f:
            pickle.dump(self, f)

    def get_results(self, suffix="", also_save=True):
        """Recursively load all results from the study directory and summarize them.

        Returns:
            tuple: (result_df, summary_df, error_report).
        """
        result_df = inspect_results.load_result_df(self.dir)
        error_report = inspect_results.error_report(result_df, max_stack_trace=3, use_log=True)
        summary_df = inspect_results.summarize_study(result_df)

        if also_save:
            tag = f"_{suffix}" if suffix else ""
            result_df.to_csv(self.dir / f"result_df{tag}.csv")
            summary_df.to_csv(self.dir / f"summary_df{tag}.csv")
            (self.dir / f"error_report{tag}.md").write_text(error_report)

        return result_df, summary_df, error_report
111+
28112
@dataclass
29-
class Study:
113+
class Study(AbstractStudy):
30114
"""A study corresponds to one or multiple agents evaluated on a benchmark.
31115
32116
This is part of the high level API to help keep experiments organized and reproducible.
@@ -142,7 +226,7 @@ def run(
142226
self._run(n_jobs, parallel_backend, strict_reproducibility)
143227

144228
suffix = f"trial_{i + 1}_of_{n_relaunch}"
145-
_, summary_df, error_report = self.get_results(suffix=suffix)
229+
_, summary_df, _ = self.get_results(suffix=suffix)
146230
logger.info("\n" + str(summary_df))
147231

148232
n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors)
@@ -200,60 +284,17 @@ def append_to_journal(self, strict_reproducibility=True):
200284
ValueError: If the reproducibility information is not compatible
201285
with the report.
202286
"""
287+
_, summary_df, _ = self.get_results()
203288
repro.append_to_journal(
204289
self.reproducibility_info,
205-
self.get_report(),
290+
summary_df,
206291
strict_reproducibility=strict_reproducibility,
207292
)
208293

209-
def get_results(self, suffix="", also_save=True):
210-
result_df = inspect_results.load_result_df(self.dir)
211-
error_report = inspect_results.error_report(result_df, max_stack_trace=3, use_log=True)
212-
summary_df = inspect_results.summarize_study(result_df)
213-
214-
if also_save:
215-
suffix = f"_{suffix}" if suffix else ""
216-
result_df.to_csv(self.dir / f"result_df{suffix}.csv")
217-
summary_df.to_csv(self.dir / f"summary_df{suffix}.csv")
218-
(self.dir / f"error_report{suffix}.md").write_text(error_report)
219-
220-
return result_df, summary_df, error_report
221-
222294
@property
223295
def name(self):
224296
agent_names = [a.agent_name for a in self.agent_args]
225-
if len(agent_names) == 1:
226-
study_name = f"{agent_names[0]}_on_{self.benchmark.name}"
227-
else:
228-
study_name = f"{len(agent_names)}_agents_on_{self.benchmark.name}"
229-
230-
study_name = slugify(study_name, max_length=100, allow_unicode=True)
231-
232-
if self.suffix:
233-
study_name += f"_{self.suffix}"
234-
return study_name
235-
236-
def make_dir(self, exp_root=RESULTS_DIR):
237-
if self.dir is None:
238-
dir_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{self.name}"
239-
240-
self.dir = Path(exp_root) / dir_name
241-
self.dir.mkdir(parents=True, exist_ok=True)
242-
243-
def save(self):
244-
"""Pickle the study to the directory"""
245-
246-
# TODO perhaps remove exp_args_list before pickling and when loading bring them from the individual directories
247-
248-
self.make_dir()
249-
250-
with gzip.open(self.dir / "study.pkl.gz", "wb") as f:
251-
pickle.dump(self, f)
252-
253-
def get_report(self, ignore_cache=False, ignore_stale=False):
254-
return inspect_results.get_study_summary(
255-
self.dir, ignore_cache=ignore_cache, ignore_stale=ignore_stale
256-
)
297+
return _make_study_name(agent_names, [self.benchmark.name], self.suffix)
257298

258299
def override_max_steps(self, max_steps):
259300
for exp_args in self.exp_args_list:
@@ -288,6 +329,64 @@ def load_most_recent(root_dir: Path = None, contains=None) -> "Study":
288329
return Study.load(get_most_recent_study(root_dir, contains=contains))
289330

290331

332+
def _make_study_name(agent_names, benchmark_names, suffix=None):
    """Make a study name from the agent and benchmark names.

    Args:
        agent_names: list of agent name strings.
        benchmark_names: list of benchmark name strings.
        suffix: optional extra tag appended to the name.

    Returns:
        str: slugified study name (at most 200 characters).
    """
    if len(agent_names) == 1:
        agent_name = agent_names[0]
    else:
        agent_name = f"{len(agent_names)}_agents"

    if len(benchmark_names) == 1:
        benchmark_name = benchmark_names[0]
    else:
        benchmark_name = f"{len(benchmark_names)}_benchmarks"

    # Only append the suffix when present, avoiding a dangling underscore.
    study_name = f"{agent_name}_on_{benchmark_name}"
    if suffix:
        study_name += f"_{suffix}"

    return slugify(study_name, max_length=200, allow_unicode=True)
347+
348+
349+
@dataclass
class SequentialStudies(AbstractStudy):
    """
    Sequential execution of multiple studies.

    This is required for e.g. WebArena, where a server reset is required between evaluations of each agent.
    """

    # The sub-studies executed one after the other.
    studies: list[Study]

    @property
    def name(self):
        """The name of the study, combining all sub-studies' agents and benchmarks."""
        agent_names = [a.agent_name for study in self.studies for a in study.agent_args]
        benchmark_names = [study.benchmark.name for study in self.studies]
        return _make_study_name(agent_names, benchmark_names, self.suffix)

    def find_incomplete(self, include_errors=True):
        """Search for missing or incomplete experiments in every sub-study.

        Returns:
            tuple[int, int]: total (n_incomplete, n_error) summed over sub-studies,
            matching the contract of Study.find_incomplete.
        """
        n_incomplete_total, n_error_total = 0, 0
        for study in self.studies:
            counts = study.find_incomplete(include_errors=include_errors)
            # Aggregate instead of discarding, so callers can unpack the counts
            # just like they do for a single Study.
            if counts is not None:
                n_incomplete_total += counts[0]
                n_error_total += counts[1]
        return n_incomplete_total, n_error_total

    def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
        """Run each sub-study in order, nesting their directories under this study's dir."""
        self.save()

        for study in self.studies:
            # Sub-study directories live inside this SequentialStudies directory.
            study.make_dir(exp_root=self.dir)
            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
        _, summary_df, _ = self.get_results()
        logger.info("\n" + str(summary_df))
        logger.info(f"SequentialStudies {self.name} finished.")

    def override_max_steps(self, max_steps):
        """Apply the same max_steps override to every sub-study."""
        for study in self.studies:
            study.override_max_steps(max_steps)

    def append_to_journal(self, strict_reproducibility=True):
        """Append each sub-study's results to the reproducibility journal."""
        for study in self.studies:
            study.append_to_journal(strict_reproducibility=strict_reproducibility)
388+
389+
291390
def get_most_recent_study(
292391
root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None
293392
):

0 commit comments

Comments
 (0)