import gzip
import logging
+import os
import pickle
import uuid
from abc import ABC, abstractmethod
from agentlab.experiments import reproducibility_util as repro
from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
+from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
+from multiprocessing import Pool, Manager, Queue

logger = logging.getLogger(__name__)

@@ -27,6 +30,7 @@ def make_study(
    suffix="",
    comment=None,
    ignore_dependencies=False,
+    parallel_servers=None,
):
    """Run a list of agents on a benchmark.

@@ -57,10 +61,17 @@ def make_study(
            3x compare to sequential executionz. To accelerate execution, you can ignore
            dependencies and run in full parallel. This leads to a decrease in performance of about
            1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
-
-    Returns:
-        Study object or SequentialStudies object if the benchmark requires manual reset after each
-        evaluation such as WebArena and VisualWebArena.
+        parallel_servers: list[WebArenaInstanceVars]
+            A list of WebArena server instances to use when `"webarena" in benchmark.name`. Use
+            this to dispatch agent_args on a pool of servers in parallel. If len(agent_args) >
+            len(parallel_servers), a server is reused (with a reset) for the next evaluation as
+            soon as it finishes the current one.
+
+    Returns: Study | SequentialStudies | ParallelStudies object.
+        SequentialStudies: if the benchmark requires a manual reset after each evaluation, such
+            as WebArena and VisualWebArena.
+        ParallelStudies: if parallel_servers is provided, to run those studies in parallel.
+        Study: otherwise.
    """

    if not isinstance(agent_args, (list, tuple)):
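For context, a minimal usage sketch of the new parallel_servers argument, not part of this diff. The import path of make_study, the benchmark key, the agent_args variable, and the WebArenaInstanceVars constructor arguments (intentionally elided below) are assumptions; the diff only guarantees that a non-None parallel_servers routes make_study to ParallelStudies.

# Hedged sketch: dispatch agents over two WebArena instances in parallel.
from agentlab.experiments.multi_server import WebArenaInstanceVars
from agentlab.experiments.study import make_study  # module path assumed

servers = [
    WebArenaInstanceVars(...),  # instance 1; constructor fields not shown in this diff
    WebArenaInstanceVars(...),  # instance 2; constructor fields not shown in this diff
]

study = make_study(
    agent_args=my_agent_args,   # hypothetical list of AgentArgs configured elsewhere
    benchmark="webarena",       # assumed key in bgym.DEFAULT_BENCHMARKS
    parallel_servers=servers,   # with this set, a ParallelStudies is returned
)
study.run(n_jobs=4)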
@@ -69,7 +80,7 @@ def make_study(
    if isinstance(benchmark, str):
        benchmark = bgym.DEFAULT_BENCHMARKS[benchmark.lower()]()

-    if "webarena" in benchmark.name and len(agent_args) > 1:
+    if len(agent_args) > 1 and ("webarena" in benchmark.name or parallel_servers is not None):
        logger.warning(
            "*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
        )
@@ -85,8 +96,10 @@ def make_study(
                    ignore_dependencies=ignore_dependencies,
                )
            )
-
-        return SequentialStudies(studies)
+        if parallel_servers is not None:
+            return ParallelStudies(studies, parallel_servers=parallel_servers)
+        else:
+            return SequentialStudies(studies)
    else:
        return Study(
            agent_args,
@@ -164,7 +177,7 @@ class Study(AbstractStudy):
            A suffix to add to the study name. This can be useful to keep track of your experiments.
            By default the study name contains agent name, benchmark name and date.
        uuid: str
-            A unique identifier for the study.
+            A unique identifier for the study. Will be generated automatically.
        reproducibility_info: dict
            Information about the study that may affect the reproducibility of the experiment. e.g.:
            versions of BrowserGym, benchmark, AgentLab...
@@ -178,12 +191,12 @@ class Study(AbstractStudy):
            information. Leave any extra information that can explain why results could be different
            than expected.
        ignore_dependencies: bool
-            If True, ignore the dependencies of the tasks in the benchmark. *Use with caution.* So
+            If True, ignore the dependencies of the tasks in the benchmark. *Use with caution*. So
            far, only WebArena and VisualWebArena have dependencies between tasks to minimize the
            influence of solving one task before another one. This dependency graph allows
            experiments to run in parallel while respecting task dependencies. However, it still
            can't run more than 4 and, in practice it's speeding up evaluation by a factor of only
-            3x compare to sequential executionz. To accelerate execution, you can ignore
+            3x compare to sequential execution. To accelerate execution, you can ignore
            dependencies and run in full parallel. This leads to a decrease in performance of about
            1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
        avg_step_timeout: int
@@ -455,13 +468,15 @@ def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
            study.make_dir(exp_root=self.dir)

        self.save()
-
-        for study in self.studies:
-            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+        self._run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
        _, summary_df, _ = self.get_results()
        logger.info("\n" + str(summary_df))
        logger.info(f"SequentialStudies {self.name} finished.")

+    def _run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
+        for study in self.studies:
+            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+
    def override_max_steps(self, max_steps):
        for study in self.studies:
            study.override_max_steps(max_steps)
@@ -471,6 +486,52 @@ def append_to_journal(self, strict_reproducibility=True):
            study.append_to_journal(strict_reproducibility=strict_reproducibility)


+def _init_worker(server_queue: Queue):
+    """Runs once when each worker of the multiprocessing.Pool is initialized.
+
+    It is typically used to point each worker's WebArena server environment variables at a
+    different instance, so that multiple instances can be evaluated in parallel.
+    """
+    server_instance = server_queue.get()  # type: WebArenaInstanceVars
+    logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
+    server_instance.init()
+
+
+def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch):
+    """Wrapper to run a study remotely."""
+    study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+
+
+@dataclass
+class ParallelStudies(SequentialStudies):
+
+    parallel_servers: list[BaseServer] | int | None = None
+
+    def _run(
+        self,
+        n_jobs=1,
+        parallel_backend="ray",
+        strict_reproducibility=False,
+        n_relaunch=3,
+    ):
+        parallel_servers = self.parallel_servers
+        if isinstance(parallel_servers, int):
+            parallel_servers = [BaseServer() for _ in range(parallel_servers)]
+
+        server_queue = Manager().Queue()
+        for server in parallel_servers:
+            server_queue.put(server)
+
+        with Pool(len(parallel_servers), initializer=_init_worker, initargs=(server_queue,)) as p:
+            p.starmap(
+                _run_study,
+                [
+                    (study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+                    for study in self.studies
+                ],
+            )
+
+
def get_most_recent_study(
    root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None
):
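The ParallelStudies dispatch above combines two standard multiprocessing pieces: a Manager queue hands exactly one server to each pool worker at initialization, and the Pool size equals the number of servers, so every study sent to a worker runs against the server that worker claimed at startup. Below is a self-contained toy sketch of the same pattern; all names in it are illustrative and not part of AgentLab.

import os
from multiprocessing import Manager, Pool

_server = None  # per-process slot, filled by the initializer in each worker


def _init_worker(server_queue):
    # Each worker process claims exactly one server from the shared queue.
    global _server
    _server = server_queue.get()
    print(f"worker {os.getpid()} claimed {_server}")


def _run_one(study_name):
    # Every task dispatched to this worker reuses the server claimed at init time.
    return f"{study_name} ran on {_server} (pid {os.getpid()})"


if __name__ == "__main__":
    servers = ["server_a", "server_b"]  # stand-ins for WebArena instances
    queue = Manager().Queue()
    for s in servers:
        queue.put(s)

    # Pool size == number of servers, so each worker drains exactly one queue entry.
    with Pool(len(servers), initializer=_init_worker, initargs=(queue,)) as pool:
        print(pool.map(_run_one, ["study_1", "study_2", "study_3", "study_4"]))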