Skip to content

Commit 5448b48

Browse files
committed
osworld bench tasks loading
1 parent 41e5298 commit 5448b48

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

src/agentlab/benchmarks/osworld.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import json
12
import logging
3+
import os
24
from dataclasses import dataclass
35
from typing import Any
46

57
from desktop_env.desktop_env import DesktopEnv
6-
from distributed.protocol.cupy import d
78

89
from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnv, AbstractEnvArgs
910

@@ -107,4 +108,29 @@ def make_env(self) -> OsworldGym:
107108

108109
class OsworldBenchmark(AbstractBenchmark):
    """Benchmark that builds one ``OsworldEnvArgs`` per OSWorld task file.

    The task index ``<test_set_path>/<test_set_name>`` is a JSON object that
    maps each domain name to its task ids. Every task is then read from
    ``<test_set_path>/examples/<domain>/<task_id>.json`` and attached to its
    own ``OsworldEnvArgs`` entry in ``env_args_list``.
    """

    name: str = "osworld"
    # Directory that holds the task index and the ``examples/`` tree.
    test_set_path: str = "OSWorld/evaluation_examples"
    # File name of the task index inside ``test_set_path``.
    test_set_name: str = "test_all.json"
    # Single domain to load, or "all" to load every domain listed in the index.
    domain: str = "all"
    env_args: OsworldEnvArgs = None  # type: ignore # basic env configuration for all tasks
    # Populated by ``model_post_init``; one entry per loaded task.
    env_args_list: list[OsworldEnvArgs] = None  # type: ignore

    def model_post_init(self, __context: Any) -> None:
        """Load the task index and build ``env_args_list``.

        Args:
            __context: Unused; present to match the post-init hook signature.

        Raises:
            KeyError: If ``domain`` is not "all" and is missing from the index.
            OSError: If the index or a task file cannot be opened.
            json.JSONDecodeError: If the index or a task file is not valid JSON.
        """
        self.env_args_list = []

        index_path = os.path.join(self.test_set_path, self.test_set_name)
        # JSON is read as UTF-8 explicitly so loading does not depend on the
        # platform's default locale encoding.
        with open(index_path, encoding="utf-8") as index_file:
            tasks = json.load(index_file)

        if self.domain != "all":
            if self.domain not in tasks:
                # Same exception type as a plain lookup, but with enough
                # context to fix the configuration.
                raise KeyError(
                    f"Unknown domain '{self.domain}' in {index_path}; "
                    f"available domains: {sorted(tasks)}"
                )
            tasks = {self.domain: tasks[self.domain]}

        for domain, task_ids in tasks.items():
            for task_id in task_ids:
                task_file = os.path.join(self.test_set_path, f"examples/{domain}/{task_id}.json")
                with open(task_file, encoding="utf-8") as task_handle:
                    task = json.load(task_handle)

                if self.env_args is not None:
                    # NOTE(review): relies on OsworldEnvArgs exposing ``.copy()``;
                    # its definition is outside this view - confirm.
                    env_args = self.env_args.copy()
                    env_args.task = task
                else:
                    env_args = OsworldEnvArgs(task=task)
                self.env_args_list.append(env_args)

        logger.info(f"Loaded {len(self.env_args_list)} tasks from domain '{self.domain}'")

0 commit comments

Comments
 (0)