Skip to content

Commit 7449033

Browse files
committed
Put video recording behind a flag; apply lint formatting
1 parent 815893c commit 7449033

File tree

4 files changed

+25
-17
lines changed

4 files changed

+25
-17
lines changed

experiments/run_osworld.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def main():
2929

3030
if os.environ.get("AGENTLAB_DEBUG"):
3131
task_ids = get_task_ids()
32-
study.exp_args_list = [exp_args for exp_args in study.exp_args_list if exp_args.env_args.task["id"] in task_ids]
32+
study.exp_args_list = [exp_args for exp_args in study.exp_args_list if exp_args.env_args.task["id"] in task_ids] # type: ignore
3333
print(f"Debug on {len(study.exp_args_list)} experiments")
34-
study.run(n_jobs=2, n_relaunch=1, parallel_backend="ray")
34+
study.run(n_jobs=4, n_relaunch=1, parallel_backend="ray")
3535
else:
3636
study.run(n_jobs=n_jobs, n_relaunch=1, parallel_backend="ray")
3737

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def set_benchmark(self, benchmark: AgentLabBenchmark | BgymBenchmark, demo_mode:
384384
if benchmark_name == "osworld":
385385
self.config.obs.use_osworld_obs_preprocessor = True
386386

387+
387388
class ToolUseAgent(bgym.Agent):
388389
def __init__(
389390
self,
@@ -598,5 +599,7 @@ def get_action(self, obs: Any) -> float:
598599
multiaction=False, # whether to use multi-action or not
599600
action_subsets=("coord",),
600601
),
601-
action_set=OSWorldActionSet("computer_13"), # or "pyautogui" #TODO: agent config should only be some primitive types.
602-
)
602+
action_set=OSWorldActionSet(
603+
"computer_13"
604+
), # or "pyautogui" #TODO: agent config should only be some primitive types.
605+
)

src/agentlab/benchmarks/osworld.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def __init__(
343343
enable_proxy: bool,
344344
max_steps: int,
345345
exp_dir: Path,
346+
record_video: bool = True,
346347
):
347348
self.task = task
348349
self.env_info = {
@@ -375,13 +376,18 @@ def __init__(
375376
self._step_count = 0
376377
self.max_steps = max_steps
377378
self.exp_dir = exp_dir
379+
self.record_video = record_video
378380

379381
def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
380382
self.env.reset(task_config=self.task, seed=seed)
381383
logging.info(f"Start solving task: {self.task['instruction']}")
382-
time.sleep(60) # Wait for the environment to be ready, as in https://github.com/xlang-ai/OSWorld/blob/main/lib_run_single.py#L15
383-
raw_obs = self.env._get_obs() # Get the initial observation
384-
self.env.controller.start_recording()
384+
time.sleep(
385+
60
386+
) # Wait for the environment to be ready, as in https://github.com/xlang-ai/OSWorld/blob/main/lib_run_single.py#L15
387+
raw_obs = self.env._get_obs() # Get the initial observation
388+
if self.record_video:
389+
self.env.controller.start_recording()
390+
logging.info("Started recording the environment video")
385391
obs = self.to_agentlab_observation(raw_obs)
386392
self._step_count = 0
387393
return obs, self.env_info
@@ -520,9 +526,10 @@ def parse_agentlab_action_str_to_func_args(action: str):
520526
return None, None, None
521527

522528
def close(self):
523-
video_name = str(self.exp_dir / "recording.mp4")
524-
self.env.controller.end_recording(video_name)
525-
logger.info(f"Recorded video saved to {video_name}")
529+
if self.record_video:
530+
video_name = str(self.exp_dir / "recording.mp4")
531+
self.env.controller.end_recording(video_name)
532+
logger.info(f"Recorded video saved to {video_name}")
526533
return self.env.close()
527534

528535

@@ -671,10 +678,8 @@ def fix_settings_file_path_in_config(self, task) -> str:
671678
osworld_repo = os.getenv("OSWORLD_REPO", "OSWorld")
672679
updated_task = deepcopy(task) # Avoid modifying the original task
673680
for config in updated_task["config"]:
674-
if config.get("parameters", False) and config["parameters"].get(
675-
"settings_file", False
676-
):
677-
config["parameters"]["settings_file"] = os.path.join(
678-
osworld_repo, config["parameters"]["settings_file"]
679-
)
681+
if config.get("parameters", False) and config["parameters"].get("settings_file", False):
682+
config["parameters"]["settings_file"] = os.path.join(
683+
osworld_repo, config["parameters"]["settings_file"]
684+
)
680685
return updated_task

src/agentlab/benchmarks/osworld_axtree_preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def collect_leaf_nodes(node, leaf_nodes):
2525
collect_leaf_nodes(root, leaf_nodes)
2626
return leaf_nodes
2727

28+
2829
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
2930
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
3031
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
@@ -337,4 +338,3 @@ def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
337338
linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
338339
linearized_accessibility_tree += "[...]\n"
339340
return linearized_accessibility_tree
340-

0 commit comments

Comments
 (0)