
Commit 64c8bc9

Merge pull request #180 from ServiceNow/parallel-study: parallel study evaluation
2 parents fc4c62f + 46f84d0

File tree: 6 files changed, +274 -12 lines

src/agentlab/agents/agent_args.py

Lines changed: 10 additions & 0 deletions
@@ -3,6 +3,16 @@


 class AgentArgs(AbstractAgentArgs):
+    """Base class for agent arguments for instantiating an agent.
+
+    Define agent arguments as dataclass variables of this class. For example:
+
+        class MyAgentArgs(AgentArgs):
+            my_arg: str = "default_value"
+            my_other_arg: int = 42
+
+    Note: to work properly with AgentXRay, the arguments need to be serializable and hashable.
+    """

     def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool):
         """Optional method to set benchmark specific flags.
src/agentlab/experiments/multi_server.py (new file)

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+from copy import deepcopy
+from dataclasses import dataclass
+import os
+import sys
+from browsergym.webarena.instance import WebArenaInstance
+
+
+class BaseServer:
+    """Base class for server instances.
+
+    Behaves like an identity function for running in parallel on servers that don't need multiple
+    instances.
+    """
+
+    def init(self):
+        pass
+
+
+@dataclass
+class WebArenaInstanceVars(BaseServer):
+    base_url: str
+    shopping: str
+    shopping_admin: str
+    reddit: str
+    gitlab: str
+    wikipedia: str
+    map: str
+    homepage: str
+    full_reset: str
+    module_name: str = "webarena"
+    prefix: str = "WA_"
+
+    def make_env_vars(self):
+        """Return a dictionary of environment variables"""
+        return {
+            f"{self.prefix}SHOPPING": f"{self.base_url}:{self.shopping}",
+            f"{self.prefix}SHOPPING_ADMIN": f"{self.base_url}:{self.shopping_admin}",
+            f"{self.prefix}REDDIT": f"{self.base_url}:{self.reddit}",
+            f"{self.prefix}GITLAB": f"{self.base_url}:{self.gitlab}",
+            f"{self.prefix}WIKIPEDIA": f"{self.base_url}:{self.wikipedia}",
+            f"{self.prefix}MAP": f"{self.base_url}:{self.map}",
+            f"{self.prefix}HOMEPAGE": f"{self.base_url}:{self.homepage}",
+            f"{self.prefix}FULL_RESET": f"{self.base_url}:{self.full_reset}",
+        }
+
+    def init(self):
+        # necessary for webarena to re-import the env vars
+        unimport_modules(self.module_name)
+        for key, value in self.make_env_vars().items():
+            os.environ[key] = value
+
+        # this is just a dynamic check to see that the env vars are set correctly
+        bgym_instance = WebArenaInstance()
+        base_url, _ = _split_url(bgym_instance.urls["reddit"])
+        assert base_url == self.base_url, f"Expected {self.base_url}, got {base_url}"
+
+    @staticmethod
+    def from_env_vars(prefix="WA_", module_name="webarena"):
+        kwargs = {"module_name": module_name}
+        base_urls = set()
+        for key, url in os.environ.items():
+            if key.startswith(prefix):
+                base_url, url_tail = _split_url(url)
+                base_urls.add(base_url)
+                kwargs[key[len(prefix) :].lower()] = url_tail
+
+        if len(base_urls) > 1:
+            raise ValueError("Multiple base urls found in environment variables")
+
+        kwargs["base_url"] = base_urls.pop()
+        return WebArenaInstanceVars(**kwargs)
+
+    def clone(self):
+        """Return a deep copy of the instance"""
+        return deepcopy(self)
+
+
+def unimport_modules(base_name):
+    """un-import any module starting with base_name"""
+    for module in sys.modules.copy():
+        if module.startswith(base_name):
+            del sys.modules[module]
+
+
+def _split_url(url: str):
+    """Extract the base url and the port/page from a url"""
+    parts = url.split(":")
+    base_url = ":".join(parts[0:2])
+    url_tail = ":".join(parts[2:])
+    return base_url, url_tail
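The environment-variable mapping above is easiest to see on a concrete instance. A small usage sketch, reusing the ports from the test added in this commit but with a placeholder host; it only calls make_env_vars(), whereas init() would additionally export the variables and re-import the browsergym webarena module:

# Usage sketch for WebArenaInstanceVars; the hosts below are placeholders.
from agentlab.experiments.multi_server import WebArenaInstanceVars

server = WebArenaInstanceVars(
    base_url="http://my-webarena-host.example.com",  # placeholder host
    shopping="8082/",
    shopping_admin="8083/admin",
    reddit="8080",
    gitlab="9001",
    wikipedia="8081/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing",
    map="443",
    homepage="80",
    full_reset="7565",
)

env_vars = server.make_env_vars()
# Each key is prefix + name and each value is base_url:tail, e.g.:
assert env_vars["WA_REDDIT"] == "http://my-webarena-host.example.com:8080"

# A second server on another host is just a clone with a different base_url.
server_2 = server.clone()
server_2.base_url = "http://my-other-webarena-host.example.com"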

src/agentlab/experiments/study.py

Lines changed: 78 additions & 11 deletions
@@ -1,5 +1,6 @@
 import gzip
 import logging
+import os
 import pickle
 import uuid
 from abc import ABC, abstractmethod
@@ -16,6 +17,8 @@
 from agentlab.experiments import reproducibility_util as repro
 from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
 from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
+from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
+from multiprocessing import Pool, Manager, Queue

 logger = logging.getLogger(__name__)

@@ -27,6 +30,7 @@ def make_study(
     suffix="",
     comment=None,
     ignore_dependencies=False,
+    parallel_servers=None,
 ):
     """Run a list of agents on a benchmark.

@@ -57,10 +61,18 @@ def make_study(
             3x compare to sequential executionz. To accelerate execution, you can ignore
             dependencies and run in full parallel. This leads to a decrease in performance of about
             1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
+        parallel_servers: list[WebArenaInstanceVars]
+            The parallel servers to use if `"webarena" in benchmark.name`. Use this to
+            dispatch agent_args on a pool of servers in parallel. If len(agent_args) >
+            len(parallel_servers), the servers will be reused for the next evaluation (with a
+            reset) as soon as one is done.

     Returns:
-        Study object or SequentialStudies object if the benchmark requires manual reset after each
-        evaluation such as WebArena and VisualWebArena.
+        Study | SequentialStudies | ParallelStudies object.
+            SequentialStudies: if the benchmark requires manual reset after each evaluation, such
+            as WebArena and VisualWebArena.
+            ParallelStudies: if the benchmark has multiple servers to run in parallel.
+            Study: otherwise.
     """

     if not isinstance(agent_args, (list, tuple)):
@@ -69,7 +81,7 @@ def make_study(
     if isinstance(benchmark, str):
         benchmark = bgym.DEFAULT_BENCHMARKS[benchmark.lower()]()

-    if "webarena" in benchmark.name and len(agent_args) > 1:
+    if len(agent_args) > 1 and ("webarena" in benchmark.name or parallel_servers is not None):
         logger.warning(
             "*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
         )
@@ -85,8 +97,10 @@ def make_study(
                     ignore_dependencies=ignore_dependencies,
                 )
             )
-
-        return SequentialStudies(studies)
+        if parallel_servers is not None:
+            return ParallelStudies(studies, parallel_servers=parallel_servers)
+        else:
+            return SequentialStudies(studies)
     else:
         return Study(
             agent_args,
@@ -164,7 +178,7 @@ class Study(AbstractStudy):
         A suffix to add to the study name. This can be useful to keep track of your experiments.
         By default the study name contains agent name, benchmark name and date.
     uuid: str
-        A unique identifier for the study.
+        A unique identifier for the study. Will be generated automatically.
     reproducibility_info: dict
         Information about the study that may affect the reproducibility of the experiment. e.g.:
         versions of BrowserGym, benchmark, AgentLab...
@@ -178,12 +192,12 @@ class Study(AbstractStudy):
         information. Leave any extra information that can explain why results could be different
         than expected.
     ignore_dependencies: bool
-        If True, ignore the dependencies of the tasks in the benchmark. *Use with caution.* So
+        If True, ignore the dependencies of the tasks in the benchmark. *Use with caution*. So
         far, only WebArena and VisualWebArena have dependencies between tasks to minimize the
         influence of solving one task before another one. This dependency graph allows
         experiments to run in parallel while respecting task dependencies. However, it still
         can't run more than 4 and, in practice it's speeding up evaluation by a factor of only
-        3x compare to sequential executionz. To accelerate execution, you can ignore
+        3x compare to sequential execution. To accelerate execution, you can ignore
         dependencies and run in full parallel. This leads to a decrease in performance of about
         1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
     avg_step_timeout: int
@@ -455,13 +469,15 @@ def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_
             study.make_dir(exp_root=self.dir)

         self.save()
-
-        for study in self.studies:
-            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+        self._run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
         _, summary_df, _ = self.get_results()
         logger.info("\n" + str(summary_df))
         logger.info(f"SequentialStudies {self.name} finished.")

+    def _run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
+        for study in self.studies:
+            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+
     def override_max_steps(self, max_steps):
         for study in self.studies:
             study.override_max_steps(max_steps)
@@ -471,6 +487,57 @@ def append_to_journal(self, strict_reproducibility=True):
             study.append_to_journal(strict_reproducibility=strict_reproducibility)


+def _init_worker(server_queue: Queue):
+    """Run once at the initialization of each worker in the multiprocessing.Pool.
+
+    This is typically used to initialize the environment variables of a different WebArena server
+    in each worker, so that multiple instances can run in parallel.
+
+    Args:
+        server_queue: Queue
+            A queue of objects implementing BaseServer to initialize (or anything with an init
+            method).
+    """
+    server_instance = server_queue.get()  # type: "WebArenaInstanceVars"
+    logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
+    server_instance.init()
+
+
+def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch):
+    """Wrapper to run a study remotely."""
+    study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+
+
+@dataclass
+class ParallelStudies(SequentialStudies):
+
+    parallel_servers: list[BaseServer] | int = None
+
+    def _run(
+        self,
+        n_jobs=1,
+        parallel_backend="ray",
+        strict_reproducibility=False,
+        n_relaunch=3,
+    ):
+        parallel_servers = self.parallel_servers
+        if isinstance(parallel_servers, int):
+            parallel_servers = [BaseServer() for _ in range(parallel_servers)]
+
+        server_queue = Manager().Queue()
+        for server in parallel_servers:
+            server_queue.put(server)
+
+        with Pool(len(parallel_servers), initializer=_init_worker, initargs=(server_queue,)) as p:
+            p.starmap(
+                _run_study,
+                [
+                    (study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+                    for study in self.studies
+                ],
+            )
+
+
 def get_most_recent_study(
     root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None
 ):
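To make the ParallelStudies flow easier to follow, here is a self-contained sketch of the same Pool-initializer pattern outside agentlab: each worker process pops one server from a shared queue when it starts and keeps it for every study it is handed. FakeServer, the global variable, and the study names are hypothetical stand-ins; the real code instead calls study.run inside each worker, and the server's init() sets WebArena environment variables.

# Sketch of the Pool + initializer pattern used by ParallelStudies (names are hypothetical).
import os
from multiprocessing import Manager, Pool

_server = None  # each worker process holds its own server after initialization


class FakeServer:
    """Stand-in for BaseServer / WebArenaInstanceVars."""

    def __init__(self, name):
        self.name = name

    def init(self):
        print(f"[pid {os.getpid()}] initialized {self.name}")


def _init_worker(server_queue):
    # Runs once per worker process: claim one server from the shared queue.
    global _server
    _server = server_queue.get()
    _server.init()


def _run_study(study_name):
    # Stand-in for study.run(...); every task in this process reuses the claimed server.
    return f"{study_name} ran on {_server.name} (pid {os.getpid()})"


if __name__ == "__main__":
    servers = [FakeServer("server_1"), FakeServer("server_2")]
    queue = Manager().Queue()
    for server in servers:
        queue.put(server)

    # Two workers -> two servers; four studies are dispatched over them as workers free up.
    with Pool(len(servers), initializer=_init_worker, initargs=(queue,)) as pool:
        print(pool.map(_run_study, ["study_a", "study_b", "study_c", "study_d"]))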
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+from agentlab.experiments.multi_server import WebArenaInstanceVars
+from browsergym.webarena.instance import WebArenaInstance
+
+
+def test_webarena_multiserver():
+
+    instance_1 = WebArenaInstanceVars(
+        base_url="http://webarena1.eastus.cloudapp.azure.com",
+        shopping="8082/",
+        shopping_admin="8083/admin",
+        reddit="8080",
+        gitlab="9001",
+        wikipedia="8081/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing",
+        map="443",
+        homepage="80",
+        full_reset="7565",
+        module_name="webarena",
+        prefix="WA_",
+    )
+
+    instance_1.init()
+
+    bgym_instance = WebArenaInstance()
+    base_url_1 = bgym_instance.urls["reddit"].rsplit(":", 1)[0]
+    assert base_url_1 == instance_1.base_url
+
+    instance_2 = instance_1.clone()
+    instance_2.base_url = "http://webarena2.eastus.cloudapp.azure.com"
+    instance_2.init()
+
+    bgym_instance = WebArenaInstance()
+    base_url_2 = bgym_instance.urls["reddit"].rsplit(":", 1)[0]
+    assert base_url_2 == instance_2.base_url
+
+
+if __name__ == "__main__":
+    test_webarena_multiserver()
tests/experiments/test_ray.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def test_execute_task_graph():
     # Verify that parallel tasks (task2 and task3) started within a short time of each other
     parallel_start_diff = abs(exp_args_list[1].start_time - exp_args_list[2].start_time)
     print(f"parallel_start_diff: {parallel_start_diff}")
-    assert parallel_start_diff < 1.5  # Allow for a small delay
+    assert parallel_start_diff < 2  # Allow for a small delay

     # Ensure that the entire task graph took the expected amount of time
     total_time = exp_args_list[-1].end_time - exp_args_list[0].start_time

tests/experiments/test_study.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import pytest
+from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
+from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs
+from agentlab.llm.chat_api import CheatMiniWoBLLMArgs
+from agentlab.experiments.study import ParallelStudies, make_study, Study
+from agentlab.experiments.multi_server import WebArenaInstanceVars
+
+
+def _make_agent_args_list():
+    # CheatMiniWoB agents won't succeed on WebArena, this is just for testing parallelization
+    agent_args_list = []
+    for i in range(2):
+        agent_args = GenericAgentArgs(
+            chat_model_args=CheatMiniWoBLLMArgs(),
+            flags=FLAGS_GPT_4o,
+        )
+
+        agent_args.agent_name = agent_args.agent_name + f"_{i}"
+        agent_args_list.append(agent_args)
+    return agent_args_list
+
+
+@pytest.mark.skip(reason="This test requires WebArena instances to be running")
+def manual_test_launch_parallel_study_webarena():
+    agent_args_list = _make_agent_args_list()
+
+    server_instance_1 = WebArenaInstanceVars.from_env_vars()
+    server_instance_2 = server_instance_1.clone()
+    server_instance_2.base_url = "http://webarena-slow.eastus.cloudapp.azure.com"
+    parallel_servers = [server_instance_1, server_instance_2]
+
+    for server in parallel_servers:
+        print(server)
+
+    study = make_study(
+        agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers
+    )
+    assert isinstance(study, ParallelStudies)
+
+    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
+
+
+def test_launch_parallel_study():
+    agent_args_list = _make_agent_args_list()
+
+    study = make_study(agent_args_list, benchmark="miniwob_tiny_test", parallel_servers=2)
+    assert isinstance(study, ParallelStudies)
+
+    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
+    _, summary_df, _ = study.get_results()
+    assert len(summary_df) == 2
+    for n_completed in summary_df["n_completed"]:
+        assert n_completed == "4/4"
+
+
+if __name__ == "__main__":
+    # test_launch_parallel_study()
+    manual_test_launch_parallel_study_webarena()
