|
| 1 | +import json |
1 | 2 | import logging |
| 3 | +import os |
2 | 4 | from dataclasses import dataclass |
3 | 5 | from typing import Any |
4 | 6 |
|
5 | 7 | from desktop_env.desktop_env import DesktopEnv |
6 | | -from distributed.protocol.cupy import d |
7 | 8 |
|
8 | 9 | from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnv, AbstractEnvArgs |
9 | 10 |
|
@@ -107,4 +108,29 @@ def make_env(self) -> OsworldGym: |
107 | 108 |
|
108 | 109 | class OsworldBenchmark(AbstractBenchmark): |
109 | 110 | name: str = "osworld" |
110 | | - env_args_list: list[OsworldEnvArgs] |
| 111 | + test_set_path: str = "OSWorld/evaluation_examples" |
| 112 | + test_set_name: str = "test_all.json" |
| 113 | + domain: str = "all" |
| 114 | + env_args: OsworldEnvArgs = None # type: ignore # basic env configuration for all tasks |
| 115 | + env_args_list: list[OsworldEnvArgs] = None # type: ignore |
| 116 | + |
| 117 | + def model_post_init(self, __context: Any) -> None: |
| 118 | + self.env_args_list = [] |
| 119 | + with open(os.path.join(self.test_set_path, self.test_set_name)) as f: |
| 120 | + tasks = json.load(f) |
| 121 | + if self.domain != "all": |
| 122 | + tasks = {self.domain: tasks[self.domain]} |
| 123 | + |
| 124 | + for domain in tasks: |
| 125 | + for task_id in tasks[domain]: |
| 126 | + task_file = os.path.join(self.test_set_path, f"examples/{domain}/{task_id}.json") |
| 127 | + with open(task_file) as f: |
| 128 | + task = json.load(f) |
| 129 | + |
| 130 | + if self.env_args: |
| 131 | + env_args = self.env_args.copy() |
| 132 | + env_args.task = task |
| 133 | + else: |
| 134 | + env_args = OsworldEnvArgs(task=task) |
| 135 | + self.env_args_list.append(env_args) |
| 136 | + logger.info(f"Loaded {len(self.env_args_list)} tasks from domain '{self.domain}'") |
0 commit comments