|
2 | 2 | import logging |
3 | 3 | import os |
4 | 4 |
|
5 | | -from agentlab.agents.tool_use_agent.tool_use_agent import OSWORLD_CLAUDE |
| 5 | +from tapeagents import agent |
| 6 | + |
| 7 | +from agentlab.agents.tool_use_agent.tool_use_agent import OSWORLD_CLAUDE, OSWORLD_OAI |
6 | 8 | from agentlab.benchmarks.osworld import OsworldBenchmark |
7 | | -from agentlab.experiments.study import make_study |
| 9 | +from agentlab.experiments.study import make_study, Study |
8 | 10 |
|
9 | 11 | fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s" |
10 | 12 | logging.basicConfig(level=logging.INFO, force=True, format=fmt, handlers=[logging.StreamHandler()]) |
11 | 13 |
|
12 | 14 |
|
| 15 | +def get_most_recent_incomplete_study() -> Study: |
| 16 | + """ |
| 17 | + Relaunch an existing study, this will continue incomplete experiments and relaunch errored experiments. |
| 18 | + """ |
| 19 | + study = Study.load_most_recent() |
| 20 | + study.find_incomplete(include_errors=True) |
| 21 | + return study |
| 22 | + |
13 | 23 | def get_task_ids() -> set[str]: |
14 | 24 | with open("experiments/osworld_debug_task_ids.json", "r") as f: |
15 | 25 | task_ids = json.load(f) |
16 | 26 | return set([task["id"] for task in task_ids]) |
17 | 27 |
|
18 | 28 |
|
19 | 29 | def main(): |
20 | | - n_jobs = 1 |
21 | | - os.environ["AGENTLAB_DEBUG"] = "1" |
| 30 | + n_jobs = 4 |
| 31 | + use_vmware = True |
| 32 | + relaunch = True |
| 33 | + agent_args = [ |
| 34 | + OSWORLD_CLAUDE, |
| 35 | + # OSWORLD_OAI # performs poorly. |
| 36 | + ] # type: ignore |
| 37 | + parallel_backend = "ray" |
| 38 | + os.environ["AGENTLAB_DEBUG"] = os.environ.get("AGENTLAB_DEBUG", "1") |
| 39 | + |
22 | 40 | study = make_study( |
23 | | - benchmark=OsworldBenchmark(test_set_name="test_small.json"), # type: ignore |
24 | | - agent_args=[OSWORLD_CLAUDE], |
| 41 | + benchmark=OsworldBenchmark(test_set_name="test_small.json"), # or test_all.json (Exper) # type: ignore |
| 42 | + agent_args=agent_args, # type: ignore |
25 | 43 | comment="osworld debug 2", |
26 | 44 | logging_level=logging.INFO, |
27 | 45 | logging_level_stdout=logging.INFO, |
28 | 46 | ) |
29 | 47 |
|
| 48 | + if use_vmware: |
| 49 | + for exp_args in study.exp_args_list: |
| 50 | + exp_args.env_args.provider_name = "vmware" # type: ignore |
| 51 | + exp_args.env_args.path_to_vm = "OSWorld/vmware_vm_data/Ubuntu0/Ubuntu0.vmx" # type: ignore |
| 52 | + parallel_backend = "sequential" |
| 53 | + |
30 | 54 | if os.environ.get("AGENTLAB_DEBUG"): |
31 | 55 | task_ids = get_task_ids() |
32 | 56 | study.exp_args_list = [exp_args for exp_args in study.exp_args_list if exp_args.env_args.task["id"] in task_ids] # type: ignore |
33 | 57 | print(f"Debug on {len(study.exp_args_list)} experiments") |
34 | | - study.run(n_jobs=4, n_relaunch=1, parallel_backend="ray") |
35 | | - else: |
36 | | - study.run(n_jobs=n_jobs, n_relaunch=1, parallel_backend="ray") |
| 58 | + n_jobs = 1 # Make sure to use 1 job when debugging in VS |
| 59 | + |
| 60 | + study = get_most_recent_incomplete_study() if relaunch else study |
| 61 | + study.run(n_jobs=n_jobs, n_relaunch=1, parallel_backend=parallel_backend) |
37 | 62 |
|
38 | 63 |
|
39 | 64 | if __name__ == "__main__": |
|
0 commit comments