Skip to content

Commit 95f423d

Browse files
recursix authored and
gasse committed
Study to multi eval (#126)
* yet another way to kill timedout jobs * Improve timeout handling in task polling logic * Add method to override max_steps in Study class * add support for tab visibility in observation flags and update related components * fix tests * black * Improve timeout handling in task polling logic * yet another way to kill timedout jobs (#108) * Add method to override max_steps in Study class * add support for tab visibility in observation flags and update related components * fix tests * black * black * Fix sorting bug. improve directory content retrieval with summary statistics * fix test * black * tmp * add error report, add cum cost to summary and ray backend by default * sequential studies --------- Co-authored-by: Maxime Gasse <[email protected]>
1 parent f951e65 commit 95f423d

File tree

1 file changed

+149
-50
lines changed

1 file changed

+149
-50
lines changed

src/agentlab/experiments/study.py

Lines changed: 149 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from abc import ABC, abstractmethod
12
import gzip
23
import logging
34
import pickle
4-
import re
55
import uuid
66
from abc import ABC, abstractmethod
77
from dataclasses import dataclass
@@ -14,16 +14,100 @@
1414

1515
from agentlab.agents.agent_args import AgentArgs
1616
from agentlab.analyze import inspect_results
17-
from agentlab.experiments import args
1817
from agentlab.experiments import reproducibility_util as repro
1918
from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
2019
from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
2120

21+
2222
logger = logging.getLogger(__name__)
2323

2424

25+
def make_study(
    agent_args: list[AgentArgs],
    benchmark: bgym.Benchmark,
    logging_level_stdout=logging.WARNING,
    suffix="",
    comment=None,
    ignore_dependencies=False,
):
    """Make a study from a list of agents and a benchmark.

    Args:
        agent_args: list of agent configurations to evaluate.
        benchmark: the benchmark to evaluate on, or the name of a benchmark
            registered in ``bgym.DEFAULT_BENCHMARKS``.
        logging_level_stdout: stdout logging level passed to each Study.
        suffix: optional suffix appended to the study name.
        comment: free-form comment attached to the study.
        ignore_dependencies: if True, ignore inter-task dependencies.

    Returns:
        SequentialStudies when multiple agents are evaluated on a *WebArena*
        benchmark (which needs a manual server reset between agents),
        otherwise a single Study.
    """
    # Resolve a benchmark given by name through the bgym registry.
    if isinstance(benchmark, str):
        benchmark = bgym.DEFAULT_BENCHMARKS[benchmark]()

    if "webarena" in benchmark.name and len(agent_args) > 1:
        logger.warning(
            "*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
        )
        # One single-agent Study per agent, run back-to-back.
        studies = [
            Study(
                [agent],
                benchmark,
                logging_level=logging_level_stdout,
                suffix=suffix,
                comment=comment,
                ignore_dependencies=ignore_dependencies,
            )
            for agent in agent_args
        ]
        return SequentialStudies(studies)
    else:
        return Study(
            agent_args,
            benchmark,
            logging_level=logging_level_stdout,
            suffix=suffix,
            comment=comment,
            ignore_dependencies=ignore_dependencies,
        )
65+
66+
67+
class AbstractStudy(ABC):
    """Shared interface and persistence helpers for Study and SequentialStudies."""

    dir: Path = None  # study directory; created lazily by make_dir
    suffix: str = ""  # optional tag appended to the study name

    @abstractmethod
    def find_incomplete(self, include_errors=True):
        """Search for missing or failed experiments."""

    @abstractmethod
    def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
        """Run the study."""

    def make_dir(self, exp_root=RESULTS_DIR):
        """Create the study directory under *exp_root* if not already set, and ensure it exists."""
        if self.dir is None:
            stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            self.dir = Path(exp_root) / f"{stamp}_{self.name}"
        self.dir.mkdir(parents=True, exist_ok=True)

    def save(self, exp_root=RESULTS_DIR):
        """Pickle the study to the directory."""
        # TODO perhaps remove exp_args_list before pickling and when loading bring them from the individual directories
        self.make_dir(exp_root=exp_root)
        with gzip.open(self.dir / "study.pkl.gz", "wb") as file:
            pickle.dump(self, file)

    def get_results(self, suffix="", also_save=True):
        """Recursively load all results from the study directory and summarize them.

        Returns:
            tuple: (result_df, summary_df, error_report) — the raw results,
            their summary, and a markdown error report.
        """
        results = inspect_results.load_result_df(self.dir)
        summary = inspect_results.summarize_study(results)
        errors = inspect_results.error_report(results, max_stack_trace=3, use_log=True)

        if also_save:
            tag = "" if not suffix else f"_{suffix}"
            results.to_csv(self.dir / f"result_df{tag}.csv")
            summary.to_csv(self.dir / f"summary_df{tag}.csv")
            (self.dir / f"error_report{tag}.md").write_text(errors)

        return results, summary, errors
107+
108+
25109
@dataclass
26-
class Study:
110+
class Study(AbstractStudy):
27111
"""A study coresponds to one or multiple agents evaluated on a benchmark.
28112
29113
This is part of the high level API to help keep experiments organized and reproducible.
@@ -139,7 +223,7 @@ def run(
139223
self._run(n_jobs, parallel_backend, strict_reproducibility)
140224

141225
suffix = f"trial_{i + 1}_of_{n_relaunch}"
142-
_, summary_df, error_report = self.get_results(suffix=suffix)
226+
_, summary_df, _ = self.get_results(suffix=suffix)
143227
logger.info("\n" + str(summary_df))
144228

145229
n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors)
@@ -197,60 +281,17 @@ def append_to_journal(self, strict_reproducibility=True):
197281
ValueError: If the reproducibility information is not compatible
198282
with the report.
199283
"""
284+
_, summary_df, _ = self.get_results()
200285
repro.append_to_journal(
201286
self.reproducibility_info,
202-
self.get_report(),
287+
summary_df,
203288
strict_reproducibility=strict_reproducibility,
204289
)
205290

206-
def get_results(self, suffix="", also_save=True):
207-
result_df = inspect_results.load_result_df(self.dir)
208-
error_report = inspect_results.error_report(result_df, max_stack_trace=3, use_log=True)
209-
summary_df = inspect_results.summarize_study(result_df)
210-
211-
if also_save:
212-
suffix = f"_{suffix}" if suffix else ""
213-
result_df.to_csv(self.dir / f"result_df{suffix}.csv")
214-
summary_df.to_csv(self.dir / f"summary_df{suffix}.csv")
215-
(self.dir / f"error_report{suffix}.md").write_text(error_report)
216-
217-
return result_df, summary_df, error_report
218-
219291
@property
220292
def name(self):
221293
agent_names = [a.agent_name for a in self.agent_args]
222-
if len(agent_names) == 1:
223-
study_name = f"{agent_names[0]}_on_{self.benchmark.name}"
224-
else:
225-
study_name = f"{len(agent_names)}_agents_on_{self.benchmark.name}"
226-
227-
study_name = slugify(study_name, max_length=100, allow_unicode=True)
228-
229-
if self.suffix:
230-
study_name += f"_{self.suffix}"
231-
return study_name
232-
233-
def make_dir(self, exp_root=RESULTS_DIR):
234-
if self.dir is None:
235-
dir_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{self.name}"
236-
237-
self.dir = Path(exp_root) / dir_name
238-
self.dir.mkdir(parents=True, exist_ok=True)
239-
240-
def save(self):
241-
"""Pickle the study to the directory"""
242-
243-
# TODO perhaps remove exp_args_list before pickling and when loading bring them from the individual directories
244-
245-
self.make_dir()
246-
247-
with gzip.open(self.dir / "study.pkl.gz", "wb") as f:
248-
pickle.dump(self, f)
249-
250-
def get_report(self, ignore_cache=False, ignore_stale=False):
251-
return inspect_results.get_study_summary(
252-
self.dir, ignore_cache=ignore_cache, ignore_stale=ignore_stale
253-
)
294+
return _make_study_name(agent_names, [self.benchmark.name], self.suffix)
254295

255296
def override_max_steps(self, max_steps):
256297
for exp_args in self.exp_args_list:
@@ -285,6 +326,64 @@ def load_most_recent(root_dir: Path = None, contains=None) -> "Study":
285326
return Study.load(get_most_recent_study(root_dir, contains=contains))
286327

287328

329+
def _make_study_name(agent_names, benchmark_names, suffix=None):
    """Make a study name from the agent and benchmark names.

    Args:
        agent_names: list of agent names in the study.
        benchmark_names: list of benchmark names in the study.
        suffix: optional suffix appended to the name.

    Returns:
        str: slugified name of the form "<agent>_on_<benchmark>[_<suffix>]".
    """
    if len(agent_names) == 1:
        agent_name = agent_names[0]
    else:
        agent_name = f"{len(agent_names)}_agents"

    if len(benchmark_names) == 1:
        benchmark_name = benchmark_names[0]
    else:
        benchmark_name = f"{len(benchmark_names)}_benchmarks"

    study_name = f"{agent_name}_on_{benchmark_name}"
    # Only append the separator when a suffix is given; the previous
    # formatting left a dangling "_" in the name for empty suffixes.
    if suffix:
        study_name += f"_{suffix}"

    return slugify(study_name, max_length=200, allow_unicode=True)
344+
345+
346+
@dataclass
class SequentialStudies(AbstractStudy):
    """
    Sequential execution of multiple studies.

    This is required for e.g. WebArena, where a server reset is required between evaluations of each agent.
    """

    studies: list[Study]

    @property
    def name(self):
        """The name of the study."""
        agent_names = [a.agent_name for study in self.studies for a in study.agent_args]
        benchmark_names = [study.benchmark.name for study in self.studies]
        return _make_study_name(agent_names, benchmark_names, self.suffix)

    def find_incomplete(self, include_errors=True):
        """Find incomplete experiments in every sub-study.

        Returns:
            tuple[int, int]: total (n_incomplete, n_error) aggregated over all
            sub-studies. The previous version discarded the sub-study counts
            and returned None, inconsistent with Study.find_incomplete.
        """
        n_incomplete, n_error = 0, 0
        for study in self.studies:
            study_incomplete, study_error = study.find_incomplete(include_errors=include_errors)
            n_incomplete += study_incomplete
            n_error += study_error
        return n_incomplete, n_error

    def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
        """Run all sub-studies sequentially, each in a sub-directory of this study's dir."""
        self.save()

        for study in self.studies:
            # Nest each sub-study's directory under the parent study's dir.
            study.make_dir(exp_root=self.dir)
            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
        _, summary_df, _ = self.get_results()
        logger.info("\n" + str(summary_df))
        logger.info(f"SequentialStudies {self.name} finished.")

    def override_max_steps(self, max_steps):
        """Propagate a max_steps override to all sub-studies."""
        for study in self.studies:
            study.override_max_steps(max_steps)

    def append_to_journal(self, strict_reproducibility=True):
        """Append each sub-study's reproducibility info to the journal."""
        for study in self.studies:
            study.append_to_journal(strict_reproducibility=strict_reproducibility)
385+
386+
288387
def get_most_recent_study(
289388
root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None
290389
):

0 commit comments

Comments
 (0)