diff --git a/furniture_bench/envs/furniture_sim_env.py b/furniture_bench/envs/furniture_sim_env.py index fc284e6..3063354 100644 --- a/furniture_bench/envs/furniture_sim_env.py +++ b/furniture_bench/envs/furniture_sim_env.py @@ -54,6 +54,7 @@ def __init__( num_envs: int = 1, resize_img: bool = True, obs_keys=None, + use_april_tag_coords: bool = True, concat_robot_state: bool = False, manual_label: bool = False, manual_done: bool = False, @@ -77,6 +78,7 @@ def __init__( num_envs (int): Number of parallel environments. resize_img (bool): If true, images are resized to 224 x 224. obs_keys (list): List of observations for observation space (i.e., RGB-D image from three cameras, proprioceptive states, and poses of the furniture parts.) + use_april_tag_coords (bool): Whether to use AprilTag coords for parts concat_robot_state (bool): Whether to return concatenated `robot_state` or its dictionary form in observation. manual_label (bool): If true, the environment reward is manually labeled. manual_done (bool): If true, the environment is terminated manually. @@ -109,6 +111,7 @@ def __init__( for furn in self.furnitures: furn.max_env_steps = max_env_steps + self.use_april_tag_coords = use_april_tag_coords self.furniture_name = furniture self.num_envs = num_envs self.obs_keys = obs_keys or DEFAULT_VISUAL_OBS @@ -174,7 +177,7 @@ def __init__( str(record_dir / "video.mp4"), cv2.VideoWriter_fourcc(*"MP4V"), 30, - (self.img_size[1] * 2, self.img_size[0]), # Wrist and front cameras. + (self.img_size[0] * 2, self.img_size[1]), # Wrist and front cameras. ) if act_rot_repr != "quat" and act_rot_repr != "axis" and act_rot_repr != "rot_6d": @@ -888,17 +891,18 @@ def _get_parts_poses(self, sim_coord=False): rb_idx = self.part_idxs[part.name][env_idx] part_pose = self.rb_states[rb_idx, :7] # To AprilTag coordinate. - part_pose = torch.concat( - [ - *C.mat2pose( - self.sim_coord_to_april_coord( - C.pose2mat( - part_pose[:3], part_pose[3:7], device=self.device + if self.use_april_tag_coords: + part_pose = torch.concat( + [ + *C.mat2pose( + self.sim_coord_to_april_coord( + C.pose2mat( + part_pose[:3], part_pose[3:7], device=self.device + ) ) ) - ) - ] - ) + ] + ) parts_poses[ env_idx, part_idx * self.pose_dim : (part_idx + 1) * self.pose_dim ] = part_pose