@@ -38,6 +38,7 @@ def __init__(
3838 require_terminal : bool ,
3939 os_type : str ,
4040 enable_proxy : bool ,
41+ max_steps : int = 50 ,
4142 ):
4243 self .task = task
4344 self .env_info = {
@@ -67,10 +68,13 @@ def __init__(
6768 require_terminal = require_terminal ,
6869 os_type = os_type ,
6970 )
71+ self ._step_count = 0
72+ self .max_steps = max_steps
7073
7174 def reset (self , seed : int | None = None ) -> tuple [dict [str , Any ], dict [str , Any ]]:
7275 raw_obs = self .env .reset (task_config = self .task , seed = seed )
7376 obs = self .env_to_agentlab_observation (raw_obs )
77+ self ._step_count = 0
7478 return obs , self .env_info
7579
7680 @add_step_timing_to_env_info_decorator
@@ -79,7 +83,8 @@ def step(self, action: str):
7983 env_action = self .agentlab_to_env_action (action )
8084 logger .info (f"AgentLab Action returned: { action } , converted to: { env_action } " )
8185 raw_obs , reward , done , info = self .env .step (env_action )
82- truncated = False # Figure out how to handle truncation in OSWorld
86+ self ._step_count += 1
87+ truncated = info .get ('fail' , False ) or self ._step_count >= self .max_steps
8388 obs = self .env_to_agentlab_observation (raw_obs )
8489 return obs , reward , done , truncated , info
8590
@@ -387,7 +392,7 @@ class OsworldEnvArgs(AbstractEnvArgs):
387392 require_terminal : bool = False
388393 os_type : str = "Ubuntu"
389394 enable_proxy : bool = False
390- # TODO: Add max steps.
395+ max_steps : int = 100
391396
392397 def make_env (
393398 self , exp_dir : Path , action_mapping = None , use_raw_page_output : bool = False
@@ -407,6 +412,7 @@ def make_env(
407412 require_terminal = self .require_terminal ,
408413 os_type = self .os_type ,
409414 enable_proxy = self .enable_proxy ,
415+ max_steps = self .max_steps ,
410416 )
411417 return gym
412418
0 commit comments