Skip to content

Commit 73baabe

Browse files
authored
Merge pull request #195 from ServiceNow/fix-daemonic-process-issue
Implement parallel processing for studies using ProcessPoolExecutor a…
2 parents 78ad38f + f7a55d7 commit 73baabe

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

src/agentlab/experiments/study.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from concurrent.futures import ProcessPoolExecutor
12
import gzip
23
import logging
34
import os
@@ -498,6 +499,8 @@ def _init_worker(server_queue: Queue):
498499
A queue of object implementing BaseServer to initialize (or anything with a init
499500
method).
500501
"""
502+
print("initializing server instance with on process", os.getpid())
503+
print(f"using queue {server_queue}")
501504
server_instance = server_queue.get() # type: "WebArenaInstanceVars"
502505
logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
503506
server_instance.init()
@@ -510,6 +513,42 @@ def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n
510513

511514
@dataclass
class ParallelStudies(SequentialStudies):
    """Run the contained studies concurrently, one worker process per server.

    Every worker process of the pool is initialised with ``_init_worker``,
    which takes one server object from a shared queue and calls its ``init()``
    method. Each study in ``self.studies`` is then submitted to the pool and
    executed through ``_run_study``.
    """

    # Either the explicit server objects to hand out to the workers, or an
    # integer N meaning "create N default ``BaseServer()`` instances".
    parallel_servers: list[BaseServer] | int | None = None

    def _run(
        self,
        n_jobs=1,
        parallel_backend="ray",
        strict_reproducibility=False,
        n_relaunch=3,
    ):
        """Execute all studies in a process pool sized to the number of servers.

        Args:
            n_jobs: Forwarded unchanged to ``_run_study`` for every study.
            parallel_backend: Forwarded unchanged to ``_run_study``.
            strict_reproducibility: Forwarded unchanged to ``_run_study``.
            n_relaunch: Forwarded unchanged to ``_run_study``.

        Raises:
            ValueError: If ``parallel_servers`` is ``None``, ``0`` or an empty
                list, i.e. there is no server to attach a worker to.
            Exception: Whatever a study raised inside its worker process is
                re-raised here by ``future.result()``.
        """
        parallel_servers = self.parallel_servers
        if isinstance(parallel_servers, int):
            parallel_servers = [BaseServer() for _ in range(parallel_servers)]

        # Fail early with an actionable message instead of an opaque
        # "NoneType is not iterable" / "max_workers must be greater than 0".
        if not parallel_servers:
            raise ValueError(
                "ParallelStudies requires at least one server in `parallel_servers`, "
                f"got {self.parallel_servers!r}."
            )

        # NOTE(review): assumes `Manager` is `multiprocessing.Manager` (its import
        # is not visible in this hunk). The context manager shuts the manager's
        # server process down deterministically once every study has finished.
        with Manager() as manager:
            # The queue is shared with the workers; each worker's initializer
            # pops exactly one server from it.
            server_queue = manager.Queue()
            for server in parallel_servers:
                server_queue.put(server)

            with ProcessPoolExecutor(
                max_workers=len(parallel_servers),
                initializer=_init_worker,
                initargs=(server_queue,),
            ) as executor:
                # Submit every study up front so they can run concurrently.
                futures = [
                    executor.submit(
                        _run_study,
                        study,
                        n_jobs,
                        parallel_backend,
                        strict_reproducibility,
                        n_relaunch,
                    )
                    for study in self.studies
                ]

                # Block until each study is done; result() re-raises any
                # exception that occurred in the worker.
                for future in futures:
                    future.result()
548+
549+
550+
@dataclass
551+
class ParallelStudies_alt(SequentialStudies):
513552

514553
parallel_servers: list[BaseServer] | int = None
515554

tests/experiments/test_study.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from agentlab.llm.chat_api import CheatMiniWoBLLMArgs
55
from agentlab.experiments.study import ParallelStudies, make_study, Study
66
from agentlab.experiments.multi_server import WebArenaInstanceVars
7+
import logging
8+
9+
10+
logging.getLogger().setLevel(logging.INFO)
711

812

913
def _make_agent_args_list():
@@ -28,13 +32,18 @@ def manual_test_launch_parallel_study_webarena():
2832
server_instance_2 = server_instance_1.clone()
2933
server_instance_2.base_url = "http://webarena-slow.eastus.cloudapp.azure.com"
3034
parallel_servers = [server_instance_1, server_instance_2]
35+
# parallel_servers = [server_instance_2]
3136

3237
for server in parallel_servers:
3338
print(server)
3439

3540
study = make_study(
36-
agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers
41+
agent_args_list,
42+
benchmark="webarena_tiny",
43+
parallel_servers=parallel_servers,
44+
ignore_dependencies=True,
3745
)
46+
study.override_max_steps(2)
3847
assert isinstance(study, ParallelStudies)
3948

4049
study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)

0 commit comments

Comments
 (0)