11import gzip
22import logging
3+ import os
34import pickle
45import uuid
56from abc import ABC , abstractmethod
1617from agentlab .experiments import reproducibility_util as repro
1718from agentlab .experiments .exp_utils import RESULTS_DIR , add_dependencies
1819from agentlab .experiments .launch_exp import find_incomplete , non_dummy_count , run_experiments
20+ from agentlab .experiments .multi_server import BaseServer , WebArenaInstanceVars
21+ from multiprocessing import Pool , Manager , Queue
1922
2023logger = logging .getLogger (__name__ )
2124
@@ -27,6 +30,7 @@ def make_study(
2730 suffix = "" ,
2831 comment = None ,
2932 ignore_dependencies = False ,
33+ parallel_servers = None ,
3034):
3135 """Run a list of agents on a benchmark.
3236
@@ -57,10 +61,18 @@ def make_study(
5761 3x compare to sequential executionz. To accelerate execution, you can ignore
5862 dependencies and run in full parallel. This leads to a decrease in performance of about
5963 1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
64+ parallel_servers: list[WebArenaInstanceVars]
65+ The number of parallel servers to use `if "webarena" in benchmark.name`. Use this to
66+ dispatch agent_args on a pool of servers in parallel. If len(agent_args) >
67+ len(parallel_servers), the servers will be reused for next evaluation (with a reset) as
68+ soon as it is done.
6069
6170 Returns:
62- Study object or SequentialStudies object if the benchmark requires manual reset after each
63- evaluation such as WebArena and VisualWebArena.
71+ Study | SequentialStudies | ParallelStudies object.
72+ SequentialStudies: if the benchmark requires manual reset after each evaluation such as
73+ WebArena and VisualWebArena.
74+ ParallelStudies: if the benchmark has multiple servers to run in parallel.
75+ Study: otherwise.
6476 """
6577
6678 if not isinstance (agent_args , (list , tuple )):
@@ -69,7 +81,7 @@ def make_study(
6981 if isinstance (benchmark , str ):
7082 benchmark = bgym .DEFAULT_BENCHMARKS [benchmark .lower ()]()
7183
72- if "webarena" in benchmark .name and len ( agent_args ) > 1 :
84+ if len ( agent_args ) > 1 and ( "webarena" in benchmark .name or parallel_servers is not None ) :
7385 logger .warning (
7486 "*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
7587 )
@@ -85,8 +97,10 @@ def make_study(
8597 ignore_dependencies = ignore_dependencies ,
8698 )
8799 )
88-
89- return SequentialStudies (studies )
100+ if parallel_servers is not None :
101+ return ParallelStudies (studies , parallel_servers = parallel_servers )
102+ else :
103+ return SequentialStudies (studies )
90104 else :
91105 return Study (
92106 agent_args ,
@@ -164,7 +178,7 @@ class Study(AbstractStudy):
164178 A suffix to add to the study name. This can be useful to keep track of your experiments.
165179 By default the study name contains agent name, benchmark name and date.
166180 uuid: str
167- A unique identifier for the study.
181+ A unique identifier for the study. Will be generated automatically.
168182 reproducibility_info: dict
169183 Information about the study that may affect the reproducibility of the experiment. e.g.:
170184 versions of BrowserGym, benchmark, AgentLab...
@@ -178,12 +192,12 @@ class Study(AbstractStudy):
178192 information. Leave any extra information that can explain why results could be different
179193 than expected.
180194 ignore_dependencies: bool
181- If True, ignore the dependencies of the tasks in the benchmark. *Use with caution.* So
195+ If True, ignore the dependencies of the tasks in the benchmark. *Use with caution*. So
182196 far, only WebArena and VisualWebArena have dependencies between tasks to minimize the
183197 influence of solving one task before another one. This dependency graph allows
184198 experiments to run in parallel while respecting task dependencies. However, it still
185199 can't run more than 4 and, in practice it's speeding up evaluation by a factor of only
186- 3x compare to sequential executionz . To accelerate execution, you can ignore
200+ 3x compare to sequential execution . To accelerate execution, you can ignore
187201 dependencies and run in full parallel. This leads to a decrease in performance of about
188202 1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
189203 avg_step_timeout: int
@@ -455,13 +469,15 @@ def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_
455469 study .make_dir (exp_root = self .dir )
456470
457471 self .save ()
458-
459- for study in self .studies :
460- study .run (n_jobs , parallel_backend , strict_reproducibility , n_relaunch )
472+ self ._run (n_jobs , parallel_backend , strict_reproducibility , n_relaunch )
461473 _ , summary_df , _ = self .get_results ()
462474 logger .info ("\n " + str (summary_df ))
463475 logger .info (f"SequentialStudies { self .name } finished." )
464476
477+ def _run (self , n_jobs = 1 , parallel_backend = "ray" , strict_reproducibility = False , n_relaunch = 3 ):
478+ for study in self .studies :
479+ study .run (n_jobs , parallel_backend , strict_reproducibility , n_relaunch )
480+
465481 def override_max_steps (self , max_steps ):
466482 for study in self .studies :
467483 study .override_max_steps (max_steps )
@@ -471,6 +487,57 @@ def append_to_journal(self, strict_reproducibility=True):
471487 study .append_to_journal (strict_reproducibility = strict_reproducibility )
472488
473489
490+ def _init_worker (server_queue : Queue ):
491+ """Run once at the initialization of the worker in the multiprocessing.Pool.
492+
493+ This is typically used to initialize different environment variables of the WebArena server for
494+ multiple instances in parallel.
495+
496+ Args:
497+ server_queue: Queue
498+ A queue of object implementing BaseServer to initialize (or anything with a init
499+ method).
500+ """
501+ server_instance = server_queue .get () # type: "WebArenaInstanceVars"
502+ logger .warning (f"Initializing server instance { server_instance } from process { os .getpid ()} " )
503+ server_instance .init ()
504+
505+
506+ def _run_study (study : Study , n_jobs , parallel_backend , strict_reproducibility , n_relaunch ):
507+ """Wrapper to run a study remotely."""
508+ study .run (n_jobs , parallel_backend , strict_reproducibility , n_relaunch )
509+
510+
511+ @dataclass
512+ class ParallelStudies (SequentialStudies ):
513+
514+ parallel_servers : list [BaseServer ] | int = None
515+
516+ def _run (
517+ self ,
518+ n_jobs = 1 ,
519+ parallel_backend = "ray" ,
520+ strict_reproducibility = False ,
521+ n_relaunch = 3 ,
522+ ):
523+ parallel_servers = self .parallel_servers
524+ if isinstance (parallel_servers , int ):
525+ parallel_servers = [BaseServer () for _ in range (parallel_servers )]
526+
527+ server_queue = Manager ().Queue ()
528+ for server in parallel_servers :
529+ server_queue .put (server )
530+
531+ with Pool (len (parallel_servers ), initializer = _init_worker , initargs = (server_queue ,)) as p :
532+ p .starmap (
533+ _run_study ,
534+ [
535+ (study , n_jobs , parallel_backend , strict_reproducibility , n_relaunch )
536+ for study in self .studies
537+ ],
538+ )
539+
540+
474541def get_most_recent_study (
475542 root_dir : Path = None , date_format : str = "%Y-%m-%d_%H-%M-%S" , contains = None
476543):
0 commit comments