Commit 18adb81

Initial commit for parallel study evaluation
1 parent c52b7cd

4 files changed (+245, -13 lines)
src/agentlab/experiments/multi_server.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
from copy import deepcopy
from dataclasses import dataclass
import os
import sys
from browsergym.webarena.instance import WebArenaInstance


class BaseServer:
    """Base class for server instances.

    Behaves like an identity function for running in parallel on servers that don't need multiple
    instances.
    """

    def init(self):
        pass

    def __str__(self):
        return "BaseServer"


@dataclass
class WebArenaInstanceVars(BaseServer):
    base_url: str
    shopping: str
    shopping_admin: str
    reddit: str
    gitlab: str
    wikipedia: str
    map: str
    homepage: str
    full_reset: str
    module_name: str = "webarena"
    prefix: str = "WA_"

    def make_env_vars(self):
        """Return a dictionary of environment variables"""
        return {
            f"{self.prefix}SHOPPING": f"{self.base_url}:{self.shopping}",
            f"{self.prefix}SHOPPING_ADMIN": f"{self.base_url}:{self.shopping_admin}",
            f"{self.prefix}REDDIT": f"{self.base_url}:{self.reddit}",
            f"{self.prefix}GITLAB": f"{self.base_url}:{self.gitlab}",
            f"{self.prefix}WIKIPEDIA": f"{self.base_url}:{self.wikipedia}",
            f"{self.prefix}MAP": f"{self.base_url}:{self.map}",
            f"{self.prefix}HOMEPAGE": f"{self.base_url}:{self.homepage}",
            f"{self.prefix}FULL_RESET": f"{self.base_url}:{self.full_reset}",
        }

    def init(self):
        # necessary for webarena to re-import the env vars
        unimport_modules(self.module_name)
        for key, value in self.make_env_vars().items():
            os.environ[key] = value
        bgym_instance = WebArenaInstance()
        base_url, _ = _split_url(bgym_instance.urls["reddit"])
        assert base_url == self.base_url, f"Expected {self.base_url}, got {base_url}"

    @staticmethod
    def from_env_vars(prefix="WA_", module_name="webarena"):
        kwargs = {"module_name": module_name}
        base_urls = set()
        for key, url in os.environ.items():
            if key.startswith(prefix):
                base_url, url_tail = _split_url(url)
                base_urls.add(base_url)
                kwargs[key[len(prefix) :].lower()] = url_tail

        if len(base_urls) > 1:
            raise ValueError("Multiple base urls found in environment variables")

        kwargs["base_url"] = base_urls.pop()
        return WebArenaInstanceVars(**kwargs)

    def clone(self):
        """Return a deep copy of the instance"""
        return deepcopy(self)


def unimport_modules(base_name):
    """un-import any module starting with base_name"""
    for module in sys.modules.copy():
        if module.startswith(base_name):
            del sys.modules[module]


def _split_url(url: str):
    """Extract the base url and the port/page from a url"""
    parts = url.split(":")
    base_url = ":".join(parts[0:2])
    url_tail = ":".join(parts[2:])
    return base_url, url_tail
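
For context, a minimal usage sketch of these helpers, assuming WA_* environment variables of the form "http://host:port" are already exported; the hostname below is a placeholder, not a real deployment:

# Sketch only; assumes the standard WA_* env vars are set, hostname is a placeholder.
from agentlab.experiments.multi_server import WebArenaInstanceVars

template = WebArenaInstanceVars.from_env_vars(prefix="WA_")  # parse the WA_* variables

server_1 = template.clone()
server_1.base_url = "http://webarena1.example.com"  # placeholder host
server_1.init()  # re-exports the WA_* variables and reloads the webarena module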

src/agentlab/experiments/study.py

Lines changed: 74 additions & 13 deletions
@@ -1,5 +1,6 @@
 import gzip
 import logging
+import os
 import pickle
 import uuid
 from abc import ABC, abstractmethod
@@ -16,6 +17,8 @@
 from agentlab.experiments import reproducibility_util as repro
 from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
 from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
+from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
+from multiprocessing import Pool, Manager, Queue

 logger = logging.getLogger(__name__)

@@ -27,6 +30,7 @@ def make_study(
     suffix="",
     comment=None,
     ignore_dependencies=False,
+    parallel_servers=None,
 ):
     """Run a list of agents on a benchmark.

@@ -57,10 +61,17 @@
             3x compare to sequential executionz. To accelerate execution, you can ignore
             dependencies and run in full parallel. This leads to a decrease in performance of about
             1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
-
-    Returns:
-        Study object or SequentialStudies object if the benchmark requires manual reset after each
-        evaluation such as WebArena and VisualWebArena.
+        parallel_servers: list[WebArenaInstanceVars]
+            The number of parallel servers to use `if "webarena" in benchmark.name`. Use this to
+            dispatch agent_args on a pool of servers in parallel. If len(agent_args) >
+            len(parallel_servers), the servers will be reused for next evaluation (with a reset) as
+            soon as it is done.
+
+    Returns: Study | SequentialStudies | ParallelStudies object.
+        SequentialStudies: if the benchmark requires manual reset after each evaluation such as
+        WebArena and VisualWebArena.
+        ParallelStudies: if the benchmark has multiple servers to run in parallel.
+        Study: otherwise.
     """

     if not isinstance(agent_args, (list, tuple)):
@@ -69,7 +80,7 @@ def make_study(
     if isinstance(benchmark, str):
         benchmark = bgym.DEFAULT_BENCHMARKS[benchmark.lower()]()

-    if "webarena" in benchmark.name and len(agent_args) > 1:
+    if len(agent_args) > 1 and ("webarena" in benchmark.name or parallel_servers is not None):
         logger.warning(
             "*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
         )
@@ -85,8 +96,10 @@ def make_study(
                     ignore_dependencies=ignore_dependencies,
                 )
             )
-
-        return SequentialStudies(studies)
+        if parallel_servers is not None:
+            return ParallelStudies(studies, parallel_servers=parallel_servers)
+        else:
+            return SequentialStudies(studies)
     else:
         return Study(
             agent_args,
@@ -164,7 +177,7 @@ class Study(AbstractStudy):
             A suffix to add to the study name. This can be useful to keep track of your experiments.
             By default the study name contains agent name, benchmark name and date.
         uuid: str
-            A unique identifier for the study.
+            A unique identifier for the study. Will be generated automatically.
         reproducibility_info: dict
             Information about the study that may affect the reproducibility of the experiment. e.g.:
             versions of BrowserGym, benchmark, AgentLab...
@@ -178,12 +191,12 @@ class Study(AbstractStudy):
             information. Leave any extra information that can explain why results could be different
             than expected.
         ignore_dependencies: bool
-            If True, ignore the dependencies of the tasks in the benchmark. *Use with caution.* So
+            If True, ignore the dependencies of the tasks in the benchmark. *Use with caution*. So
             far, only WebArena and VisualWebArena have dependencies between tasks to minimize the
             influence of solving one task before another one. This dependency graph allows
             experiments to run in parallel while respecting task dependencies. However, it still
             can't run more than 4 and, in practice it's speeding up evaluation by a factor of only
-            3x compare to sequential executionz. To accelerate execution, you can ignore
+            3x compare to sequential execution. To accelerate execution, you can ignore
             dependencies and run in full parallel. This leads to a decrease in performance of about
             1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
         avg_step_timeout: int
@@ -455,13 +468,15 @@ def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_
             study.make_dir(exp_root=self.dir)

         self.save()
-
-        for study in self.studies:
-            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+        self._run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
         _, summary_df, _ = self.get_results()
         logger.info("\n" + str(summary_df))
         logger.info(f"SequentialStudies {self.name} finished.")

+    def _run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
+        for study in self.studies:
+            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+
     def override_max_steps(self, max_steps):
         for study in self.studies:
             study.override_max_steps(max_steps)
@@ -471,6 +486,52 @@ def append_to_journal(self, strict_reproducibility=True):
             study.append_to_journal(strict_reproducibility=strict_reproducibility)


+def _init_worker(server_queue: Queue):
+    """Run once at the initialization of the worker in the multiprocessing.Pool.
+
+    This is typically used to initialize different environment variables of the WebArena server for
+    multiple instances in parallel.
+    """
+    server_instance = server_queue.get()  # type: "WebArenaInstanceVars"
+    logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
+    server_instance.init()
+
+
+def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch):
+    """Wrapper to run a study remotely."""
+    study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+
+
+@dataclass
+class ParallelStudies(SequentialStudies):
+
+    parallel_servers: list[BaseServer] | int = None
+
+    def _run(
+        self,
+        n_jobs=1,
+        parallel_backend="ray",
+        strict_reproducibility=False,
+        n_relaunch=3,
+    ):
+        parallel_servers = self.parallel_servers
+        if isinstance(parallel_servers, int):
+            parallel_servers = [BaseServer() for _ in range(parallel_servers)]
+
+        server_queue = Manager().Queue()
+        for server in parallel_servers:
+            server_queue.put(server)
+
+        with Pool(len(parallel_servers), initializer=_init_worker, initargs=(server_queue,)) as p:
+            p.starmap(
+                _run_study,
+                [
+                    (study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+                    for study in self.studies
+                ],
+            )
+
+
 def get_most_recent_study(
     root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None
 ):
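
ParallelStudies above hands each worker process one server description through a managed queue (via the Pool initializer) and then spreads the studies over those workers. A stripped-down sketch of that pattern, independent of AgentLab and with purely illustrative names:

# Sketch of the Pool + Manager().Queue() initializer pattern used by ParallelStudies.
from multiprocessing import Manager, Pool
import os


def init_worker(queue):
    # Each worker pops exactly one server description when it starts.
    server = queue.get()
    print(f"worker {os.getpid()} bound to {server}")


def run_job(job):
    return f"{job} finished in pid {os.getpid()}"


if __name__ == "__main__":
    servers = ["server_a", "server_b"]  # illustrative placeholders
    queue = Manager().Queue()
    for server in servers:
        queue.put(server)
    # One process per server; jobs are distributed over the initialized workers.
    with Pool(len(servers), initializer=init_worker, initargs=(queue,)) as pool:
        print(pool.map(run_job, ["study_1", "study_2"]))
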
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
from agentlab.experiments.multi_server import WebArenaInstanceVars
from browsergym.webarena.instance import WebArenaInstance


def test_webarena_multiserver():
    instance = WebArenaInstanceVars.from_env_vars()
    instance_1 = instance.clone()
    instance_1.base_url = "http://webarena1.eastus.cloudapp.azure.com"
    instance_1.init()

    bgym_instance = WebArenaInstance()
    base_url_1 = bgym_instance.urls["reddit"].rsplit(":", 1)[0]
    assert base_url_1 == instance_1.base_url

    instance_2 = instance.clone()
    instance_2.base_url = "http://webarena2.eastus.cloudapp.azure.com"
    instance_2.init()

    bgym_instance = WebArenaInstance()
    base_url_2 = bgym_instance.urls["reddit"].rsplit(":", 1)[0]
    assert base_url_2 == instance_2.base_url


if __name__ == "__main__":
    test_webarena_multiserver()

tests/experiments/test_study.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
import pytest
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs
from agentlab.llm.chat_api import CheatMiniWoBLLMArgs
from agentlab.experiments.study import ParallelStudies, make_study, Study
from agentlab.experiments.multi_server import WebArenaInstanceVars


def _make_agent_args_list():
    # CheatMiniWoB agents won't succeed on WebArena, this is just for testing parallelization
    agent_args_list = []
    for i in range(2):
        agent_args = GenericAgentArgs(
            chat_model_args=CheatMiniWoBLLMArgs(),
            flags=FLAGS_GPT_4o,
        )

        agent_args.agent_name = agent_args.agent_name + f"_{i}"
        agent_args_list.append(agent_args)
    return agent_args_list


@pytest.mark.skip(reason="This test requires WebArena instances to be running")
def test_launch_parallel_study_webarena():
    agent_args_list = _make_agent_args_list()

    server_instance_1 = WebArenaInstanceVars.from_env_vars()
    server_instance_2 = server_instance_1.clone()
    parallel_servers = [server_instance_1, server_instance_2]

    study = make_study(
        agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers
    )
    assert isinstance(study, ParallelStudies)

    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)


def test_launch_parallel_study():
    agent_args_list = _make_agent_args_list()

    study = make_study(agent_args_list, benchmark="miniwob_tiny_test", parallel_servers=2)
    assert isinstance(study, ParallelStudies)

    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
    _, summary_df, _ = study.get_results()
    assert len(summary_df) == 2
    for n_completed in summary_df["n_completed"]:
        assert n_completed == "4/4"

    study_ = Study.load_study(study.study_dir)


if __name__ == "__main__":
    test_launch_parallel_study()
