Skip to content

Commit adbaf2d

Browse files
committed
boilerplate
1 parent dce2633 commit adbaf2d

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ matplotlib
2626
ray[default]
2727
python-slugify
2828
pillow
29-
gymnasium>=0.27
29+
gymnasium>=0.27
30+
desktop-env~=0.1.22

src/agentlab/benchmarks/osworld.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import logging
2+
from dataclasses import dataclass
3+
from math import log
4+
from typing import Any
5+
6+
from desktop_env.desktop_env import DesktopEnv
7+
from distributed.protocol.cupy import d
8+
9+
from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnv, AbstractEnvArgs
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class OsworldGym(AbstractEnv):
    """Wrap a ``DesktopEnv`` instance behind the ``AbstractEnv`` interface.

    The constructor stores the task, records every configuration value in
    ``env_info`` and creates the underlying ``DesktopEnv``.
    """

    def __init__(
        self,
        task: dict,
        provider_name: str = "vmware",
        region: str | None = None,
        path_to_vm: str | None = None,
        snapshot_name: str = "init_state",
        action_space: str = "computer_13",
        cache_dir: str = "cache",
        screen_size: tuple[int, int] = (1920, 1080),
        headless: bool = False,
        require_a11y_tree: bool = True,
        require_terminal: bool = False,
        os_type: str = "Ubuntu",
        enable_proxy: bool = False,
    ):
        """Store the task and settings, then create the ``DesktopEnv``.

        Args:
            task: Task configuration handed to ``DesktopEnv.reset`` on reset.
            provider_name: Forwarded to ``DesktopEnv``.
            region: Forwarded to ``DesktopEnv``.
            path_to_vm: Forwarded to ``DesktopEnv``.
            snapshot_name: Forwarded to ``DesktopEnv``.
            action_space: Forwarded to ``DesktopEnv``.
            cache_dir: Forwarded to ``DesktopEnv``.
            screen_size: Forwarded to ``DesktopEnv``.
            headless: Forwarded to ``DesktopEnv``.
            require_a11y_tree: Forwarded to ``DesktopEnv``.
            require_terminal: Forwarded to ``DesktopEnv``.
            os_type: Forwarded to ``DesktopEnv``.
            enable_proxy: Recorded in ``env_info`` only; it is not passed to
                ``DesktopEnv``.
        """
        self.task = task
        # Key order mirrors the order callers see when env_info is returned.
        self.env_info = dict(
            provider_name=provider_name,
            region=region,
            path_to_vm=path_to_vm,
            snapshot_name=snapshot_name,
            action_space=action_space,
            cache_dir=cache_dir,
            screen_size=screen_size,
            headless=headless,
            require_a11y_tree=require_a11y_tree,
            require_terminal=require_terminal,
            os_type=os_type,
            enable_proxy=enable_proxy,
        )
        # DesktopEnv receives every recorded setting except enable_proxy.
        desktop_kwargs = {
            key: value for key, value in self.env_info.items() if key != "enable_proxy"
        }
        self.env = DesktopEnv(**desktop_kwargs)

    def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset the wrapped environment with the stored task.

        Args:
            seed: Passed through to ``DesktopEnv.reset``.

        Returns:
            The value returned by ``DesktopEnv.reset`` paired with ``env_info``.
        """
        return self.env.reset(task_config=self.task, seed=seed), self.env_info

    def step(self, action: str):
        """Apply ``action`` to the wrapped environment.

        Returns:
            ``(obs, reward, done, truncated, info)``. ``DesktopEnv.step`` yields
            four values; ``truncated`` is inserted here and is always ``False``.
        """
        observation, reward, terminated, info = self.env.step(action)
        return observation, reward, terminated, False, info

    def close(self):
        """Close the wrapped environment and return whatever it returns."""
        outcome = self.env.close()
        return outcome
71+
72+
73+
@dataclass
class OsworldEnvArgs(AbstractEnvArgs):
    """Settings needed to build an ``OsworldGym``.

    Every field maps one-to-one onto the ``OsworldGym`` constructor parameter
    of the same name.
    """

    # Task configuration passed to OsworldGym as ``task``.
    task: dict[str, Any]
    provider_name: str
    region: str | None
    path_to_vm: str | None
    snapshot_name: str
    action_space: str
    cache_dir: str
    screen_size: tuple[int, int]
    headless: bool
    require_a11y_tree: bool
    require_terminal: bool
    os_type: str
    enable_proxy: bool

    def make_env(self) -> OsworldGym:
        """Log the task and return a new ``OsworldGym`` built from these settings."""
        logger.info(f"Creating OsworldGym with task: {self.task}")
        gym_kwargs = {
            "task": self.task,
            "provider_name": self.provider_name,
            "region": self.region,
            "path_to_vm": self.path_to_vm,
            "snapshot_name": self.snapshot_name,
            "action_space": self.action_space,
            "cache_dir": self.cache_dir,
            "screen_size": self.screen_size,
            "headless": self.headless,
            "require_a11y_tree": self.require_a11y_tree,
            "require_terminal": self.require_terminal,
            "os_type": self.os_type,
            "enable_proxy": self.enable_proxy,
        }
        return OsworldGym(**gym_kwargs)
107+
108+
109+
class OsworldBenchmark(AbstractBenchmark):
    """Benchmark declaration named ``"osworld"``.

    Holds the list of ``OsworldEnvArgs`` entries for the benchmark.
    NOTE(review): how ``AbstractBenchmark`` consumes these attributes is not
    visible in this file; confirm against the base class.
    """

    # Benchmark identifier.
    name: str = "osworld"
    # Environment settings; presumably one entry per task (TODO confirm).
    env_args_list: list[OsworldEnvArgs]

0 commit comments

Comments
 (0)