Skip to content

Commit 2d7d5a2

Browse files
Add max_steps parameter to OsworldGym and OsworldEnvArgs for step limit control
1 parent 400b947 commit 2d7d5a2

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/agentlab/benchmarks/osworld.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ def __init__(
3838
require_terminal: bool,
3939
os_type: str,
4040
enable_proxy: bool,
41+
max_steps: int = 50,
4142
):
4243
self.task = task
4344
self.env_info = {
@@ -67,10 +68,13 @@ def __init__(
6768
require_terminal=require_terminal,
6869
os_type=os_type,
6970
)
71+
self._step_count = 0
72+
self.max_steps = max_steps
7073

7174
def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
7275
raw_obs = self.env.reset(task_config=self.task, seed=seed)
7376
obs = self.env_to_agentlab_observation(raw_obs)
77+
self._step_count = 0
7478
return obs, self.env_info
7579

7680
@add_step_timing_to_env_info_decorator
@@ -79,7 +83,8 @@ def step(self, action: str):
7983
env_action = self.agentlab_to_env_action(action)
8084
logger.info(f"AgentLab Action returned: {action}, converted to: {env_action}")
8185
raw_obs, reward, done, info = self.env.step(env_action)
82-
truncated = False # Figure out how to handle truncation in OSWorld
86+
self._step_count += 1
87+
truncated = info.get('fail', False) or self._step_count >= self.max_steps
8388
obs = self.env_to_agentlab_observation(raw_obs)
8489
return obs, reward, done, truncated, info
8590

@@ -387,7 +392,7 @@ class OsworldEnvArgs(AbstractEnvArgs):
387392
require_terminal: bool = False
388393
os_type: str = "Ubuntu"
389394
enable_proxy: bool = False
390-
# TODO: Add max steps.
395+
max_steps: int = 100
391396

392397
def make_env(
393398
self, exp_dir: Path, action_mapping=None, use_raw_page_output: bool = False
@@ -407,6 +412,7 @@ def make_env(
407412
require_terminal=self.require_terminal,
408413
os_type=self.os_type,
409414
enable_proxy=self.enable_proxy,
415+
max_steps=self.max_steps,
410416
)
411417
return gym
412418

0 commit comments

Comments
 (0)