|
| 1 | +import json |
| 2 | +import logging |
| 3 | +import os |
| 4 | + |
| 5 | +from agentlab.agents.tool_use_agent.tool_use_agent import OSWORLD_CLAUDE |
| 6 | +from agentlab.benchmarks.osworld import OsworldBenchmark |
| 7 | +from agentlab.experiments.study import Study, make_study |
| 8 | + |
| 9 | +fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s" |
| 10 | +logging.basicConfig(level=logging.INFO, force=True, format=fmt, handlers=[logging.StreamHandler()]) |
| 11 | + |
| 12 | + |
| 13 | +def get_most_recent_incomplete_study() -> Study: |
| 14 | + """ |
| 15 | + Relaunch an existing study, this will continue incomplete experiments and relaunch errored experiments. |
| 16 | + """ |
| 17 | + study = Study.load_most_recent() |
| 18 | + study.find_incomplete(include_errors=True) |
| 19 | + return study |
| 20 | + |
| 21 | + |
| 22 | +def get_task_ids() -> set[str]: |
| 23 | + with open("experiments/osworld_debug_task_ids.json", "r") as f: |
| 24 | + task_ids = json.load(f) |
| 25 | + return set([task["id"] for task in task_ids]) |
| 26 | + |
| 27 | + |
| 28 | +def main(): |
| 29 | + n_jobs = 4 |
| 30 | + use_vmware = True |
| 31 | + relaunch = False |
| 32 | + agent_args = [ |
| 33 | + OSWORLD_CLAUDE, |
| 34 | + # OSWORLD_OAI # performs poorly. |
| 35 | + ] # type: ignore |
| 36 | + parallel_backend = "ray" |
| 37 | + os.environ["AGENTLAB_DEBUG"] = os.environ.get("AGENTLAB_DEBUG", "1") |
| 38 | + |
| 39 | + study = make_study( |
| 40 | + benchmark=OsworldBenchmark( |
| 41 | + test_set_name="test_small.json" |
| 42 | + ), # or test_all.json (Exper) # type: ignore |
| 43 | + agent_args=agent_args, # type: ignore |
| 44 | + comment="osworld debug 2", |
| 45 | + logging_level=logging.INFO, |
| 46 | + logging_level_stdout=logging.INFO, |
| 47 | + ) |
| 48 | + |
| 49 | + if use_vmware: |
| 50 | + for exp_args in study.exp_args_list: |
| 51 | + exp_args.env_args.provider_name = "vmware" # type: ignore |
| 52 | + exp_args.env_args.path_to_vm = "OSWorld/vmware_vm_data/Ubuntu0/Ubuntu0.vmx" # type: ignore |
| 53 | + parallel_backend = "sequential" |
| 54 | + |
| 55 | + if os.environ.get("AGENTLAB_DEBUG"): |
| 56 | + task_ids = get_task_ids() |
| 57 | + study.exp_args_list = [exp_args for exp_args in study.exp_args_list if exp_args.env_args.task["id"] in task_ids] # type: ignore |
| 58 | + print(f"Debug on {len(study.exp_args_list)} experiments") |
| 59 | + n_jobs = 1 # Make sure to use 1 job when debugging in VS |
| 60 | + |
| 61 | + study = get_most_recent_incomplete_study() if relaunch else study |
| 62 | + study.run(n_jobs=n_jobs, n_relaunch=1, parallel_backend=parallel_backend) |
| 63 | + |
| 64 | + |
| 65 | +if __name__ == "__main__": |
| 66 | + main() |
0 commit comments