1 | | -from concurrent.futures import ProcessPoolExecutor |
2 | 1 | import gzip |
3 | 2 | import logging |
4 | 3 | import os |
5 | 4 | import pickle |
| 5 | +import random |
6 | 6 | import uuid |
7 | 7 | from abc import ABC, abstractmethod |
| 8 | +from concurrent.futures import ProcessPoolExecutor |
8 | 9 | from dataclasses import dataclass |
9 | 10 | from datetime import datetime |
| 11 | +from multiprocessing import Manager, Pool, Queue |
10 | 12 | from pathlib import Path |
11 | 13 |
12 | 14 | import bgym |
19 | 21 | from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies |
20 | 22 | from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments |
21 | 23 | from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars |
22 | | -from multiprocessing import Pool, Manager, Queue |
23 | | -import random |
24 | 24 |
25 | 25 | logger = logging.getLogger(__name__) |
26 | 26 |
@@ -238,7 +238,7 @@ def __post_init__(self): |
238 | 238 |
239 | 239 | def make_exp_args_list(self): |
240 | 240 | """Generate the exp_args_list from the agent_args and the benchmark.""" |
241 | | - self.exp_args_list = _agents_on_benchmark( |
| 241 | + self.exp_args_list = self.agents_on_benchmark( |
242 | 242 | self.agent_args, |
243 | 243 | self.benchmark, |
244 | 244 | logging_level=self.logging_level, |
@@ -424,6 +424,84 @@ def load(dir: Path) -> "Study": |
424 | 424 | def load_most_recent(root_dir: Path = None, contains=None) -> "Study": |
425 | 425 | return Study.load(get_most_recent_study(root_dir, contains=contains)) |
426 | 426 |
| 427 | + def agents_on_benchmark( |
| 428 | + self, |
| 429 | + agents: list[AgentArgs] | AgentArgs, |
| 430 | + benchmark: bgym.Benchmark, |
| 431 | + demo_mode=False, |
| 432 | + logging_level: int = logging.INFO, |
| 433 | + logging_level_stdout: int = logging.INFO, |
| 434 | + ignore_dependencies=False, |
| 435 | + ): |
| 436 | + """Run one or multiple agents on a benchmark. |
| 437 | +
| 438 | + Args: |
| 439 | + agents: list[AgentArgs] | AgentArgs |
| 440 | + The agent configuration(s) to run. |
| 441 | + benchmark: bgym.Benchmark |
| 442 | + The benchmark to run the agents on. |
| 443 | + demo_mode: bool |
| 444 | + If True, the experiments will be run in demo mode. |
| 445 | + logging_level: int |
| 446 | + The logging level for individual jobs. |
| 447 | + logging_level_stdout: int |
| 448 | + The logging level for the stdout. |
| 449 | + ignore_dependencies: bool |
| 450 | + If True, the dependencies will be ignored and all experiments can be run in parallel. |
| 451 | +
| 452 | + Returns: |
| 453 | + list[ExpArgs]: The list of experiments to run. |
| 454 | +
| 455 | + Raises: |
| 456 | + ValueError: If multiple agents are run on a benchmark that requires manual reset. |
| 457 | + """ |
| 458 | + |
| 459 | + if not isinstance(agents, (list, tuple)): |
| 460 | + agents = [agents] |
| 461 | + |
| 462 | + if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"): |
| 463 | + if len(agents) > 1: |
| 464 | + raise ValueError( |
| 465 | + f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation." |
| 466 | + ) |
| 467 | + |
| 468 | + for agent in agents: |
| 469 | + agent.set_benchmark( |
| 470 | + benchmark, demo_mode |
| 471 | + ) # the agent can adapt (lightly?) to the benchmark |
| 472 | + |
| 473 | + env_args_list = benchmark.env_args_list |
| 474 | + if demo_mode: |
| 475 | + set_demo_mode(env_args_list) |
| 476 | + |
| 477 | + exp_args_list = [] |
| 478 | + |
| 479 | + for agent in agents: |
| 480 | + for env_args in env_args_list: |
| 481 | + exp_args = ExpArgs( |
| 482 | + agent_args=agent, |
| 483 | + env_args=env_args, |
| 484 | + logging_level=logging_level, |
| 485 | + logging_level_stdout=logging_level_stdout, |
| 486 | + ) |
| 487 | + exp_args_list.append(exp_args) |
| 488 | + |
| 489 | + for i, exp_args in enumerate(exp_args_list): |
| 490 | + exp_args.order = i |
| 491 | + |
| 492 | + # not required with ray, but keeping around in case we need it for visualwebarena on joblib |
| 493 | + # _flag_sequential_exp(exp_args_list, benchmark) |
| 494 | + |
| 495 | + if not ignore_dependencies: |
| 496 | + # populate the depends_on field based on the task dependencies in the benchmark |
| 497 | + exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks()) |
| 498 | + else: |
| 499 | + logger.warning( |
| 500 | + f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results." |
| 501 | + ) |
| 502 | + |
| 503 | + return exp_args_list |
| 504 | + |
427 | 505 |
428 | 506 | def _make_study_name(agent_names, benchmark_names, suffix=None): |
429 | 507 | """Make a study name from the agent and benchmark names.""" |
@@ -634,82 +712,6 @@ def set_demo_mode(env_args_list: list[EnvArgs]): |
634 | 712 | env_args.slow_mo = 1000 |
635 | 713 |
636 | 714 |
637 | | -def _agents_on_benchmark( |
638 | | - agents: list[AgentArgs] | AgentArgs, |
639 | | - benchmark: bgym.Benchmark, |
640 | | - demo_mode=False, |
641 | | - logging_level: int = logging.INFO, |
642 | | - logging_level_stdout: int = logging.INFO, |
643 | | - ignore_dependencies=False, |
644 | | -): |
645 | | - """Run one or multiple agents on a benchmark. |
646 | | -
647 | | - Args: |
648 | | - agents: list[AgentArgs] | AgentArgs |
649 | | - The agent configuration(s) to run. |
650 | | - benchmark: bgym.Benchmark |
651 | | - The benchmark to run the agents on. |
652 | | - demo_mode: bool |
653 | | - If True, the experiments will be run in demo mode. |
654 | | - logging_level: int |
655 | | - The logging level for individual jobs. |
656 | | - logging_level_stdout: int |
657 | | - The logging level for the stdout. |
658 | | - ignore_dependencies: bool |
659 | | - If True, the dependencies will be ignored and all experiments can be run in parallel. |
660 | | -
661 | | - Returns: |
662 | | - list[ExpArgs]: The list of experiments to run. |
663 | | -
664 | | - Raises: |
665 | | - ValueError: If multiple agents are run on a benchmark that requires manual reset. |
666 | | - """ |
667 | | - |
668 | | - if not isinstance(agents, (list, tuple)): |
669 | | - agents = [agents] |
670 | | - |
671 | | - if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"): |
672 | | - if len(agents) > 1: |
673 | | - raise ValueError( |
674 | | - f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation." |
675 | | - ) |
676 | | - |
677 | | - for agent in agents: |
678 | | - agent.set_benchmark(benchmark, demo_mode) # the agent can adapt (lightly?) to the benchmark |
679 | | - |
680 | | - env_args_list = benchmark.env_args_list |
681 | | - if demo_mode: |
682 | | - set_demo_mode(env_args_list) |
683 | | - |
684 | | - exp_args_list = [] |
685 | | - |
686 | | - for agent in agents: |
687 | | - for env_args in env_args_list: |
688 | | - exp_args = ExpArgs( |
689 | | - agent_args=agent, |
690 | | - env_args=env_args, |
691 | | - logging_level=logging_level, |
692 | | - logging_level_stdout=logging_level_stdout, |
693 | | - ) |
694 | | - exp_args_list.append(exp_args) |
695 | | - |
696 | | - for i, exp_args in enumerate(exp_args_list): |
697 | | - exp_args.order = i |
698 | | - |
699 | | - # not required with ray, but keeping around in case we need it for visualwebarena on joblib |
700 | | - # _flag_sequential_exp(exp_args_list, benchmark) |
701 | | - |
702 | | - if not ignore_dependencies: |
703 | | - # populate the depends_on field based on the task dependencies in the benchmark |
704 | | - exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks()) |
705 | | - else: |
706 | | - logger.warning( |
707 | | - f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results." |
708 | | - ) |
709 | | - |
710 | | - return exp_args_list |
711 | | - |
712 | | - |
713 | 715 | # def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark): |
714 | 716 | # if benchmark.name.startswith("visualwebarena"): |
715 | 717 | # sequential_subset = benchmark.subset_from_glob("requires_reset", "True") |
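For context on the refactor: make_exp_args_list now delegates to the new Study.agents_on_benchmark method instead of the removed module-level _agents_on_benchmark. Below is a minimal usage sketch, not part of the diff; the import path and the Study constructor fields are inferred from the call site shown above, and my_agent_args / my_benchmark are placeholders.

import logging

import bgym

# Module path assumed from the imports shown in this diff; adjust if Study lives elsewhere.
from agentlab.experiments.study import Study

# Placeholders: a concrete AgentArgs configuration and a bgym Benchmark instance.
my_agent_args = ...   # e.g. a preconfigured agent from an agent package
my_benchmark: bgym.Benchmark = ...  # e.g. a benchmark obtained from bgym

study = Study(
    agent_args=[my_agent_args],
    benchmark=my_benchmark,
    logging_level=logging.INFO,  # field names inferred from the call in make_exp_args_list
)

# Builds one ExpArgs per (agent, task) pair, assigns `order`, and adds task
# dependencies -- now via Study.agents_on_benchmark rather than the removed
# module-level _agents_on_benchmark.
study.make_exp_args_list()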