Commit c1e953e
Merge pull request #228 from ServiceNow/tlsdc/study_updates
Switching _agents_on_benchmark to Study method for flexibility
2 parents 67f186d + b8d2bae commit c1e953e
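
What this merge changes: _agents_on_benchmark, previously a private module-level helper, becomes the Study method agents_on_benchmark, so subclasses can override how experiments are generated. A minimal sketch of that flexibility (the subclass and its shuffling behavior are illustrative assumptions, not part of this PR; only the Study class and the method signature come from the diff below):

import random

from agentlab.experiments.study import Study


class ShuffledStudy(Study):
    """Hypothetical Study subclass: override the experiment-generation hook."""

    def agents_on_benchmark(self, *args, **kwargs):
        # reuse the parent logic introduced by this commit, then post-process
        exp_args_list = super().agents_on_benchmark(*args, **kwargs)
        random.shuffle(exp_args_list)
        # reassign the order field, mirroring the parent's bookkeeping
        for i, exp_args in enumerate(exp_args_list):
            exp_args.order = i
        return exp_args_list

Because make_exp_args_list now dispatches through self.agents_on_benchmark (see its hunk in the diff below), such an override is picked up automatically.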

File tree: 1 file changed, +82 -80 lines

src/agentlab/experiments/study.py (+82, -80)

src/agentlab/experiments/study.py

@@ -1,12 +1,14 @@
-from concurrent.futures import ProcessPoolExecutor
 import gzip
 import logging
 import os
 import pickle
+import random
 import uuid
 from abc import ABC, abstractmethod
+from concurrent.futures import ProcessPoolExecutor
 from dataclasses import dataclass
 from datetime import datetime
+from multiprocessing import Manager, Pool, Queue
 from pathlib import Path
 
 import bgym
@@ -19,8 +21,6 @@
 from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
 from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
 from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
-from multiprocessing import Pool, Manager, Queue
-import random
 
 logger = logging.getLogger(__name__)
 
@@ -238,7 +238,7 @@ def __post_init__(self):
 
     def make_exp_args_list(self):
         """Generate the exp_args_list from the agent_args and the benchmark."""
-        self.exp_args_list = _agents_on_benchmark(
+        self.exp_args_list = self.agents_on_benchmark(
             self.agent_args,
             self.benchmark,
             logging_level=self.logging_level,
@@ -424,6 +424,84 @@ def load(dir: Path) -> "Study":
     def load_most_recent(root_dir: Path = None, contains=None) -> "Study":
         return Study.load(get_most_recent_study(root_dir, contains=contains))
 
+    def agents_on_benchmark(
+        self,
+        agents: list[AgentArgs] | AgentArgs,
+        benchmark: bgym.Benchmark,
+        demo_mode=False,
+        logging_level: int = logging.INFO,
+        logging_level_stdout: int = logging.INFO,
+        ignore_dependencies=False,
+    ):
+        """Run one or multiple agents on a benchmark.
+
+        Args:
+            agents: list[AgentArgs] | AgentArgs
+                The agent configuration(s) to run.
+            benchmark: bgym.Benchmark
+                The benchmark to run the agents on.
+            demo_mode: bool
+                If True, the experiments will be run in demo mode.
+            logging_level: int
+                The logging level for individual jobs.
+            logging_level_stdout: int
+                The logging level for the stdout.
+            ignore_dependencies: bool
+                If True, the dependencies will be ignored and all experiments can be run in parallel.
+
+        Returns:
+            list[ExpArgs]: The list of experiments to run.
+
+        Raises:
+            ValueError: If multiple agents are run on a benchmark that requires manual reset.
+        """
+
+        if not isinstance(agents, (list, tuple)):
+            agents = [agents]
+
+        if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
+            if len(agents) > 1:
+                raise ValueError(
+                    f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
+                )
+
+        for agent in agents:
+            agent.set_benchmark(
+                benchmark, demo_mode
+            )  # the agent can adapt (lightly?) to the benchmark
+
+        env_args_list = benchmark.env_args_list
+        if demo_mode:
+            set_demo_mode(env_args_list)
+
+        exp_args_list = []
+
+        for agent in agents:
+            for env_args in env_args_list:
+                exp_args = ExpArgs(
+                    agent_args=agent,
+                    env_args=env_args,
+                    logging_level=logging_level,
+                    logging_level_stdout=logging_level_stdout,
+                )
+                exp_args_list.append(exp_args)
+
+        for i, exp_args in enumerate(exp_args_list):
+            exp_args.order = i
+
+        # not required with ray, but keeping around if we would need it for visualwebareana on joblib
+        # _flag_sequential_exp(exp_args_list, benchmark)
+
+        if not ignore_dependencies:
+            # populate the depends_on field based on the task dependencies in the benchmark
+            exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
+        else:
+            logger.warning(
+                f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
+            )
+
+        return exp_args_list
+
 
 def _make_study_name(agent_names, benchmark_names, suffix=None):
     """Make a study name from the agent and benchmark names."""
@@ -634,82 +712,6 @@ def set_demo_mode(env_args_list: list[EnvArgs]):
         env_args.slow_mo = 1000
 
 
-def _agents_on_benchmark(
-    agents: list[AgentArgs] | AgentArgs,
-    benchmark: bgym.Benchmark,
-    demo_mode=False,
-    logging_level: int = logging.INFO,
-    logging_level_stdout: int = logging.INFO,
-    ignore_dependencies=False,
-):
-    """Run one or multiple agents on a benchmark.
-
-    Args:
-        agents: list[AgentArgs] | AgentArgs
-            The agent configuration(s) to run.
-        benchmark: bgym.Benchmark
-            The benchmark to run the agents on.
-        demo_mode: bool
-            If True, the experiments will be run in demo mode.
-        logging_level: int
-            The logging level for individual jobs.
-        logging_level_stdout: int
-            The logging level for the stdout.
-        ignore_dependencies: bool
-            If True, the dependencies will be ignored and all experiments can be run in parallel.
-
-    Returns:
-        list[ExpArgs]: The list of experiments to run.
-
-    Raises:
-        ValueError: If multiple agents are run on a benchmark that requires manual reset.
-    """
-
-    if not isinstance(agents, (list, tuple)):
-        agents = [agents]
-
-    if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
-        if len(agents) > 1:
-            raise ValueError(
-                f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
-            )
-
-    for agent in agents:
-        agent.set_benchmark(benchmark, demo_mode)  # the agent can adapt (lightly?) to the benchmark
-
-    env_args_list = benchmark.env_args_list
-    if demo_mode:
-        set_demo_mode(env_args_list)
-
-    exp_args_list = []
-
-    for agent in agents:
-        for env_args in env_args_list:
-            exp_args = ExpArgs(
-                agent_args=agent,
-                env_args=env_args,
-                logging_level=logging_level,
-                logging_level_stdout=logging_level_stdout,
-            )
-            exp_args_list.append(exp_args)
-
-    for i, exp_args in enumerate(exp_args_list):
-        exp_args.order = i
-
-    # not required with ray, but keeping around if we would need it for visualwebareana on joblib
-    # _flag_sequential_exp(exp_args_list, benchmark)
-
-    if not ignore_dependencies:
-        # populate the depends_on field based on the task dependencies in the benchmark
-        exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
-    else:
-        logger.warning(
-            f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
-        )
-
-    return exp_args_list
-
-
 # def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark):
 #     if benchmark.name.startswith("visualwebarena"):
 #         sequential_subset = benchmark.subset_from_glob("requires_reset", "True")

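For reference, a sketch of the resulting call path (build_exp_args is a hypothetical stand-in for make_exp_args_list; the method name, signature, and Study attributes come from the diff above):

import logging

from agentlab.experiments.study import Study


def build_exp_args(study: Study):
    """Hypothetical helper mirroring what make_exp_args_list now does."""
    # agents and benchmark remain explicit parameters rather than being read
    # from self, which preserves the old helper's signature
    return study.agents_on_benchmark(
        study.agent_args,
        study.benchmark,
        logging_level=logging.INFO,
        ignore_dependencies=False,  # keep cross-task dependency ordering (the default)
    )

Keeping the signature unchanged means the only call-site update in this diff is the one in make_exp_args_list.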