Skip to content

Commit 30742cf

Browse files
committed
default args in the dataclass
1 parent f29f048 commit 30742cf

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

src/agentlab/benchmarks/osworld.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@ class OsworldGym(AbstractEnv):
1414
def __init__(
1515
self,
1616
task: dict,
17-
provider_name: str = "vmware",
18-
region: str | None = None,
19-
path_to_vm: str | None = None,
20-
snapshot_name: str = "init_state",
21-
action_space: str = "computer_13",
22-
cache_dir: str = "cache",
23-
screen_size: tuple[int, int] = (1920, 1080),
24-
headless: bool = False,
25-
require_a11y_tree: bool = True,
26-
require_terminal: bool = False,
27-
os_type: str = "Ubuntu",
28-
enable_proxy: bool = False,
17+
provider_name: str,
18+
region: str | None,
19+
path_to_vm: str | None,
20+
snapshot_name: str,
21+
action_space: str,
22+
cache_dir: str,
23+
screen_size: tuple[int, int],
24+
headless: bool,
25+
require_a11y_tree: bool,
26+
require_terminal: bool,
27+
os_type: str,
28+
enable_proxy: bool,
2929
):
3030
self.task = task
3131
self.env_info = {
@@ -72,21 +72,21 @@ def close(self):
7272
@dataclass
7373
class OsworldEnvArgs(AbstractEnvArgs):
7474
task: dict[str, Any]
75-
provider_name: str
76-
region: str | None
77-
path_to_vm: str | None
78-
snapshot_name: str
79-
action_space: str
80-
cache_dir: str
81-
screen_size: tuple[int, int]
82-
headless: bool
83-
require_a11y_tree: bool
84-
require_terminal: bool
85-
os_type: str
86-
enable_proxy: bool
75+
path_to_vm: str | None = None
76+
provider_name: str = "vmware" # path to .vmx file
77+
region: str = "us-east-1" # AWS specific, does not apply to all providers
78+
snapshot_name: str = "init_state" # snapshot name to revert to
79+
action_space: str = "computer_13" # "computer_13" | "pyautogui"
80+
cache_dir: str = "cache"
81+
screen_size: tuple[int, int] = (1920, 1080)
82+
headless: bool = False
83+
require_a11y_tree: bool = True
84+
require_terminal: bool = False
85+
os_type: str = "Ubuntu"
86+
enable_proxy: bool = False
8787

8888
def make_env(self) -> OsworldGym:
89-
logger.info(f"Creating OsworldGym with task: {self.task}")
89+
logger.info(f"Creating OSWorld Gym with task: {self.task}")
9090
gym = OsworldGym(
9191
task=self.task,
9292
provider_name=self.provider_name,

0 commit comments

Comments
 (0)