diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3e36a0d981..ea2b4b9d06 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -47,6 +47,8 @@ title: Using Libero - local: metaworld title: Using MetaWorld + - local: rlbench + title: Using RLBench title: "Simulation" - sections: - local: introduction_processors diff --git a/docs/source/rlbench.mdx b/docs/source/rlbench.mdx new file mode 100644 index 0000000000..f0bca9f74d --- /dev/null +++ b/docs/source/rlbench.mdx @@ -0,0 +1,122 @@ +# RLBench + +**RLBench** is a large-scale benchmark designed to accelerate research in **robot learning**, with a strong focus on **vision-guided manipulation**. It provides a challenging and standardized environment for developing and testing algorithms that can learn complex robotic tasks. + +- 📄 [RLBench paper](https://arxiv.org/abs/1909.12271) +- 💻 [Original RLBench repo](https://github.com/stepjam/RLBench) + +![RLBench Tasks](https://github.com/stepjam/RLBench/blob/master/readme_files/task_grid.png) + +## Why RLBench? + +- **Diverse and Challenging Tasks:** RLBench includes over 100 unique, hand-designed tasks, ranging from simple reaching and pushing to complex, multi-stage activities like opening an oven and placing a tray inside. This diversity tests an algorithm's ability to generalize across different objectives and dynamics. +- **Rich, Multi-Modal Observations:** The benchmark provides both proprioceptive (joint states) and visual observations. Visual data comes from multiple camera angles, including over-the-shoulder and wrist cameras, with options for RGB, depth, and segmentation masks. +- **Infinite Demonstrations:** A key feature of RLBench is its ability to generate an infinite supply of demonstrations for each task. These demonstrations are created using motion planners, making RLBench an ideal platform for research in imitation learning and offline reinforcement learning. +- **Scalability and Customization:** RLBench is designed to be extensible. Researchers can easily create and contribute new tasks, helping the benchmark evolve and stay relevant. + +RLBench includes **eight task sets**, each a collection of multiple tasks (FS=Few-Shot, MT=Multi-Task). + +- **`FS10_v1`** – 10 training tasks, 5 test tasks +- **`FS25_v1`** – 25 training tasks, 5 test tasks +- **`FS50_v1`** – 50 training tasks, 5 test tasks +- **`FS95_v1`** – 95 training tasks, 5 test tasks +- **`MT15_v1`** – 15 training tasks (all tasks of `FS10_v1`, training+test) +- **`MT30_v1`** – 30 training tasks (all tasks of `FS25_v1`, training+test) +- **`MT55_v1`** – 55 training tasks (all tasks of `FS50_v1`, training+test) +- **`MT100_v1`** – 100 training tasks (all tasks of `FS95_v1`, training+test) + +For details about the tasks and task sets, please refer to the [original definition](https://github.com/stepjam/RLBench/blob/master/rlbench/tasks/__init__.py). + +## RLBench in LeRobot + +LeRobot's integration with RLBench allows you to train and evaluate policies on its rich set of tasks. The integration is designed to be seamless, leveraging LeRobot's training and evaluation pipelines. + +### Get started + +RLBench is built around CoppeliaSim v4.1.0 and [PyRep](https://github.com/stepjam/PyRep).
+ +First, install CoppeliaSim: + +```bash +# set environment variables +export COPPELIASIM_ROOT=${HOME}/CoppeliaSim +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT +export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT + +wget https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz +mkdir -p $COPPELIASIM_ROOT && tar -xf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz -C $COPPELIASIM_ROOT --strip-components 1 +rm -rf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz +``` + +Next, install the necessary dependencies: + +```bash +pip install pycparser # needed when installing rlbench from git +pip install -e ".[rlbench]" +``` + +That's it! You can now use RLBench environments within LeRobot. To run headless, check the documentation in the original [RLBench repo](https://github.com/stepjam/RLBench). + +### Evaluating a Policy + +You can evaluate a trained policy on a specific RLBench task or a suite of tasks. + +```bash +export COPPELIASIM_ROOT=${HOME}/CoppeliaSim +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT +export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT + +lerobot-eval \ + --policy.path="your-policy-id" \ + --env.type=rlbench \ + --env.task=put_rubbish_in_bin \ + --eval.batch_size=1 \ + --eval.n_episodes=10 +``` + +- `--env.task` specifies the RLBench task to evaluate on. You can also use task suites like `FS10_V1` or `MT30_V1`. +- The evaluation script reports the success rate for the given task(s). + +### Training a Policy + +You can train a policy on RLBench tasks using the `lerobot-train` command. You'll need a dataset in the LeRobot format (see [RLBench Datasets](#rlbench-datasets) below). + +```bash +export COPPELIASIM_ROOT=${HOME}/CoppeliaSim +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT +export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT + +lerobot-train \ + --policy.type=smolvla \ + --policy.repo_id=${HF_USER}/rlbench-test \ + --dataset.repo_id=lerobot/rlbench_put_rubbish_in_bin \ + --env.type=rlbench \ + --env.task=put_rubbish_in_bin \ + --output_dir=./outputs/ \ + --steps=100000 \ + --batch_size=4 \ + --eval.batch_size=1 \ + --eval.n_episodes=10 \ + --eval_freq=1000 +``` + +> If running on a headless server, ensure that the CoppeliaSim environment is set up to run without a GUI. +> Refer to the [RLBench documentation](https://github.com/stepjam/RLBench). + +### RLBench Datasets + +LeRobot expects datasets to be in a specific format. While there isn't an official `lerobot`-prepared RLBench dataset on the Hugging Face Hub yet, you can create your own by converting demonstrations from the original RLBench format. + +The environment expects the following observation and action keys: + +- **Observations:** + - `observation.state`: End-effector pose and gripper state (7 values: `x, y, z, roll, pitch, yaw, gripper`). + - `observation.state.joints`: Robot joint positions (7 values). + - `observation.images.front_rgb`: Front RGB camera view. + - `observation.images.wrist_rgb`: Wrist RGB camera view. + - `observation.images.overhead_rgb`: Overhead RGB camera view. + - `observation.images.left_shoulder_rgb`: Left shoulder RGB camera view. + - `observation.images.right_shoulder_rgb`: Right shoulder RGB camera view. +- **Actions:** + - A continuous control vector for the robot's end effector or joints plus the gripper (e.g., for the Franka arm, 8 dimensions: 7 joint positions + 1 gripper state, or an end-effector pose delta + 1 gripper state). + +Make sure your dataset's metadata and parquet files use these keys to ensure compatibility with LeRobot's RLBench environment.
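+ +To generate such a dataset yourself, this integration ships an example script, `examples/dataset/rlbench_collect_dataset.py`, which records live RLBench demos straight into a `LeRobotDataset`. A minimal invocation (the repo id below is an illustrative placeholder) looks like: + +```bash +python examples/dataset/rlbench_collect_dataset.py \ + --repo_id ${HF_USER}/rlbench_put_rubbish_in_bin \ + --task put_rubbish_in_bin \ + --num_episodes 100 \ + --action_repr euler +```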
diff --git a/examples/dataset/rlbench_collect_dataset.py b/examples/dataset/rlbench_collect_dataset.py new file mode 100644 index 0000000000..a517ef8ba4 --- /dev/null +++ b/examples/dataset/rlbench_collect_dataset.py @@ -0,0 +1,328 @@ +import argparse +import os +import shutil + +import numpy as np + +# RLBench +from rlbench import CameraConfig, Environment +from rlbench.action_modes.action_mode import MoveArmThenGripper +from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaIK +from rlbench.action_modes.gripper_action_modes import Discrete +from rlbench.demo import Demo +from rlbench.observation_config import ObservationConfig +from rlbench.utils import name_to_task_class +from scipy.spatial.transform import Rotation +from tqdm import tqdm + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# You can define how end-effector actions and rotations are represented. +# The action represents either the joint positions (joint1...joint7, gripper_open) +# or the end-effector pose (position, rotation, gripper_open), absolute or relative (delta from the current pose). +# The rotation can be represented as either Euler angles (3 values) or a quaternion (4 values). +EULER_EEF = "euler" # Actions have 7 values: [x, y, z, roll, pitch, yaw, gripper_state] +QUAT_EEF = "quat" # Actions have 8 values: [x, y, z, qx, qy, qz, qw, gripper_state] +JOINTS = "joints" # Actions have 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper_state] + + +def get_target_pose(demo: Demo, index: int): + """Get the target end-effector pose (position + quaternion) and gripper open state for a specific observation in the demo.""" + return np.array( + [*demo._observations[max(0, index)].gripper_pose, demo._observations[max(0, index)].gripper_open] + ) + + +def get_target_joints(demo: Demo, index: int): + """Get the target joint positions for a specific observation in the demo.""" + return np.array( + [*demo._observations[max(0, index)].joint_positions, demo._observations[max(0, index)].gripper_open] + ) + + +def action_conversion( + action: np.ndarray, + to_representation: str = "euler", + is_relative: bool = False, + previous_action: np.ndarray | None = None, +): + """Convert an action between Euler and quaternion representations. + + Args: + action: np.ndarray of shape (7,) (Euler) or (8,) (quaternion). + Euler format: [x, y, z, roll, pitch, yaw, gripper] + Quaternion format: [x, y, z, qx, qy, qz, qw, gripper] + to_representation: 'euler' or 'quat' target representation. + is_relative: if True, compute the delta between `action` and + `previous_action` and return the delta (position and + rotation). `previous_action` must be provided in this case. + previous_action: previous action (same format as `action`) used when + is_relative is True. + + Returns: + np.ndarray converted action in the target representation. For relative + mode the returned rotation represents the delta rotation (as Euler + angles or as a unit quaternion depending on `to_representation`). + + Notes: + - Quaternion ordering is (qx, qy, qz, qw) to match the rest of the + codebase. Rotation objects from scipy are created/consumed with + this ordering via as_quat(scalar_first=False). + - When producing quaternions we always normalize to guard against + numerical drift.
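+ + Example (illustrative): + >>> a = np.array([0.1, 0.2, 0.3, 0.0, 0.0, 0.0, 1.0]) # absolute Euler action, gripper open + >>> action_conversion(a, to_representation="quat") + # -> [0.1, 0.2, 0.3, 0.0, 0.0, 0.0, 1.0, 1.0] (identity rotation becomes the unit quaternion, scalar last)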
+ """ + + if to_representation not in ("euler", "quat"): + raise ValueError("to_representation must be 'euler' or 'quat'") + + a = np.asarray(action, dtype=float) + if a.size not in (7, 8): + raise ValueError("action must be length 7 (Euler) or 8 (quaternion)") + + if is_relative and previous_action is None: + raise ValueError("previous_action must be provided when is_relative is True") + + def _ensure_unit_quat(q): + q = np.asarray(q, dtype=float) + n = np.linalg.norm(q) + if n == 0: + raise ValueError("Zero quaternion encountered") + return q / n + + # Helper: construct Rotation from either euler or quat stored in action array + def _rot_from_action(arr): + arr = np.asarray(arr, dtype=float) + if arr.size == 7: + return Rotation.from_euler("xyz", arr[3:6], degrees=False) + else: + return Rotation.from_quat(arr[3:7]) # (qx, qy, qz, qw) + + # Gripper state (keep as-is, demo code expects absolute gripper state even for deltas) + gripper = a[-1] + + # Relative case: compute deltas + if is_relative: + prev = np.asarray(previous_action, dtype=float) + if prev.size not in (7, 8): + raise ValueError("previous_action must be length 7 or 8") + + delta_pos = a[:3] - prev[:3] + + # If both are Euler, simple subtraction of angles is fine + if a.size == 7 and prev.size == 7: + delta_ang = a[3:6] - prev[3:6] + if to_representation == "euler": + return np.array([*delta_pos, *delta_ang, gripper], dtype=float) + else: + # convert delta Euler to quaternion (and normalize) + q = Rotation.from_euler("xyz", delta_ang, degrees=False).as_quat(scalar_first=False) + q = _ensure_unit_quat(q) + return np.array([*delta_pos, *q, gripper], dtype=float) + + # Otherwise use rotation algebra to compute the delta rotation + r_cur = _rot_from_action(a) + r_prev = _rot_from_action(prev) + r_delta = r_cur * r_prev.inv() + + if to_representation == "euler": + delta_ang = r_delta.as_euler("xyz", degrees=False) + return np.array([*delta_pos, *delta_ang, gripper], dtype=float) + else: + q = r_delta.as_quat(scalar_first=False) + q = _ensure_unit_quat(q) + return np.array([*delta_pos, *q, gripper], dtype=float) + + # Absolute case: just convert representations + if to_representation == "euler": + if a.size == 7: + return a.astype(float) + else: + euler = Rotation.from_quat(a[3:7]).as_euler("xyz", degrees=False) + return np.array([*a[:3], *euler, gripper], dtype=float) + else: # to_representation == 'quat' + if a.size == 8: + q = _ensure_unit_quat(a[3:7]) + return np.array([*a[:3], *q, gripper], dtype=float) + else: + q = Rotation.from_euler("xyz", a[3:6], degrees=False).as_quat(scalar_first=False) + q = _ensure_unit_quat(q) + return np.array([*a[:3], *q, gripper], dtype=float) + + +# ------------------------ +# Main +# ------------------------ + + +def main(args): + task_class = name_to_task_class(args.task) + + # RLBench setup + camera_config = CameraConfig(image_size=(args.image_height, args.image_width)) + obs_config = ObservationConfig( + left_shoulder_camera=camera_config, + right_shoulder_camera=camera_config, + overhead_camera=camera_config, + wrist_camera=camera_config, + front_camera=camera_config, + ) + obs_config.set_all(True) + + action_mode = MoveArmThenGripper( + arm_action_mode=EndEffectorPoseViaIK(absolute_mode=args.absolute_actions), + gripper_action_mode=Discrete(), + ) + env = Environment(action_mode, obs_config=obs_config, headless=True) + env.launch() + task = env.get_task(task_class) + + # Remove the dataset root if already exists + if os.path.exists(args.save_path): + print(f"Dataset root {args.save_path} 
already exists. Removing it.") + shutil.rmtree(args.save_path) + + camera_names = ["left_shoulder_rgb", "right_shoulder_rgb", "front_rgb", "wrist_rgb", "overhead_rgb"] + + action_feature = {} + if args.action_repr == "euler": + action_feature = { + "shape": (7,), # pos(3) + euler(3) + gripper(1) + "dtype": "float32", + "names": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"], + "description": "End-effector position (x,y,z), orientation (roll,pitch,yaw) and gripper state (0.0 closed, 1.0 open).", + } + elif args.action_repr == "quat": + action_feature = { + "shape": (8,), # pos(3) + quat(4) + gripper(1) + "dtype": "float32", + "names": ["x", "y", "z", "qx", "qy", "qz", "qw", "gripper"], + "description": "End-effector position (x,y,z), orientation (qx,qy,qz,qw) and gripper state (0.0 closed, 1.0 open).", + } + elif args.action_repr == "joints": + action_feature = { + "shape": (8,), # joint_1 to joint_7 + gripper(1) + "dtype": "float32", + "names": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7", "gripper"], + "description": "Robot joint positions (absolute rotations) and gripper state (0.0 closed, 1.0 open).", + } + + dataset = LeRobotDataset.create( + repo_id=args.repo_id, + fps=args.fps, + root=args.save_path, + robot_type="franka", + features={ + "observation.state": { + "dtype": "float32", + "shape": (7,), # pos(3) + euler(3) + gripper(1) + "names": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"], + "description": "End-effector position (x,y,z), orientation (roll,pitch,yaw) and gripper state (0.0 closed, 1.0 open).", + }, + "observation.state.joints": { + "dtype": "float32", + "shape": (7,), + "names": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"], + "description": "Robot joint positions (absolute rotations).", + }, + "action": action_feature, + # All camera images + **{ + f"observation.images.{cam}": { + "dtype": "image", + "shape": (args.image_height, args.image_width, 3), + "names": ["height", "width", "channels"], + # "dtype": "video", + # "info": { + # "video.fps": args.fps, + # "video.height": args.image_height, + # "video.width": args.image_width, + # "video.channels": 3, + # "video.is_depth_map": False, + # "has_audio": False, + # }, + } + for cam in camera_names + }, + }, + ) + + # Collect demonstrations and add them to the LeRobot dataset + print(f"Generating {args.num_episodes} demos for task: {args.task}") + for _ in tqdm(range(args.num_episodes), desc="Collecting demos"): + # generate a new demo + demo = task.get_demos(1, live_demos=True)[0] + + for frame_index, observation in enumerate(demo): + action = None + if args.action_repr in ["euler", "quat"]: + action = action_conversion( + get_target_pose(demo, frame_index + 1 if frame_index < len(demo) - 1 else frame_index), + args.action_repr, + not args.absolute_actions, + get_target_pose(demo, frame_index), + ) + elif args.action_repr == "joints": + action = get_target_joints( + demo, frame_index + 1 if frame_index < len(demo) - 1 else frame_index + ) + + # Create the frame data, following the same structure as the features defined above + frame_data = { + "observation.state": action_conversion(get_target_pose(demo, frame_index)).astype(np.float32), + "observation.state.joints": observation.joint_positions.astype(np.float32), + "action": action.astype(np.float32), + "task": task.get_name(), + } + for cam in camera_names: + frame_data[f"observation.images.{cam}"] = getattr(observation, cam) + + # Save the frame + dataset.add_frame(frame_data) + 
dataset.save_episode() + env.shutdown() + + # dataset.push_to_hub() + print(f"\033[92mDataset saved to {args.save_path}. Uncomment dataset.push_to_hub() to upload it to the Hub as {args.repo_id}.\033[0m") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Collect RLBench demonstrations and save to LeRobot dataset format." + ) + parser.add_argument( + "--save_path", + type=str, + default=os.path.join(os.getcwd(), "datasets"), + help="Path to save the LeRobot dataset.", + ) + parser.add_argument( + "--repo_id", + type=str, + required=True, + help="HuggingFace Hub repository ID (e.g., 'username/dataset-name').", + ) + parser.add_argument("--num_episodes", type=int, default=100, help="Number of demonstrations to record.") + parser.add_argument("--task", type=str, default="put_rubbish_in_bin", help="Name of the RLBench task.") + parser.add_argument( + "--action_repr", + type=str, + choices=["euler", "quat", "joints"], + default="euler", + help="Action representation: 'euler' for Euler angles, 'quat' for quaternions, or 'joints' for joint positions.", + ) + parser.add_argument( + "--absolute_actions", + action="store_true", + default=False, + help="Whether to use absolute actions (default: False). Valid only for 'euler' and 'quat' action representations.", + ) + parser.add_argument( + "--fps", type=int, default=30, help="Frames per second of the recorded dataset (default: 30)." + ) + parser.add_argument("--image_width", type=int, default=256, help="Image width in pixels (default: 256).") + parser.add_argument( + "--image_height", type=int, default=256, help="Image height in pixels (default: 256)." + ) + + args = parser.parse_args() + main(args) diff --git a/pyproject.toml b/pyproject.toml index e58f42a373..1ad2e3faf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,6 +144,7 @@ aloha = ["gym-aloha>=0.1.2,<0.2.0"] pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"] metaworld = ["metaworld==3.0.0"] +rlbench = ["rlbench @ git+https://github.com/stepjam/RLBench.git"] # All all = [ @@ -167,6 +168,7 @@ all = [ "lerobot[phone]", "lerobot[libero]", "lerobot[metaworld]", + "lerobot[rlbench]", ] [project.scripts] diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index dc526114dc..40d9bb29fa 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.robots import RobotConfig from lerobot.teleoperators.config import TeleoperatorConfig -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STATE_JOINTS @@ -319,3 +319,70 @@ def gym_kwargs(self) -> dict: "obs_type": self.obs_type, "render_mode": self.render_mode, } + + +@EnvConfig.register_subclass("rlbench") +@dataclass +class RLBenchEnv(EnvConfig): + task: str = "FS10_V1" # can also choose other task suites or single tasks + fps: int = 30 + episode_length: int = 400 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + camera_name: str = "left_shoulder_rgb,right_shoulder_rgb,front_rgb,wrist_rgb,overhead_rgb" + camera_name_mapping: dict[str, str] | None = None + observation_height: int = 256 + observation_width: int = 256 + task_ids: str | None = None + features: dict[str, PolicyFeature] =
field( + default_factory=lambda: { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + ACTION: ACTION, + "agent_pos": OBS_STATE, + "agent_joints": OBS_STATE_JOINTS, + "pixels/front_rgb": f"{OBS_IMAGES}.front_rgb", + "pixels/wrist_rgb": f"{OBS_IMAGES}.wrist_rgb", + "pixels/left_shoulder_rgb": f"{OBS_IMAGES}.left_shoulder_rgb", + "pixels/right_shoulder_rgb": f"{OBS_IMAGES}.right_shoulder_rgb", + "pixels/overhead_rgb": f"{OBS_IMAGES}.overhead_rgb", + } + ) + + def __post_init__(self): + all_cameras = ["front_rgb", "wrist_rgb", "left_shoulder_rgb", "right_shoulder_rgb", "overhead_rgb"] + + if self.obs_type == "pixels": + for cam in all_cameras: + self.features[f"pixels/{cam}"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) + + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(7,)) + self.features["agent_joints"] = PolicyFeature(type=FeatureType.STATE, shape=(7,)) + for cam in all_cameras: + self.features[f"pixels/{cam}"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) + + elif self.obs_type == "state": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(7,)) + self.features["agent_joints"] = PolicyFeature(type=FeatureType.STATE, shape=(7,)) + + else: + raise ValueError(f"Unsupported obs_type: {self.obs_type}") + + @property + def gym_kwargs(self) -> dict: + kwargs = { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + "observation_width": self.observation_width, + "observation_height": self.observation_height, + "camera_name_mapping": self.camera_name_mapping, + } + + if self.task_ids is not None: + kwargs["task_ids"] = self.task_ids + return kwargs diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 52c7cbb966..eb45c12835 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -18,7 +18,7 @@ import gymnasium as gym from gymnasium.envs.registration import registry as gym_registry -from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, RLBenchEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -28,6 +28,8 @@ return PushtEnv(**kwargs) elif env_type == "libero": return LiberoEnv(**kwargs) + elif env_type == "rlbench": + return RLBenchEnv(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") @@ -85,6 +87,19 @@ def make_env( gym_kwargs=cfg.gym_kwargs, env_cls=env_cls, ) + elif "rlbench" in cfg.type: + from lerobot.envs.rlbench import create_rlbench_envs + + if cfg.task is None: + raise ValueError("RLBench requires a task to be specified") + + return create_rlbench_envs( + task=cfg.task, + n_envs=n_envs, + camera_name=cfg.camera_name, + gym_kwargs=cfg.gym_kwargs, + env_cls=env_cls, + ) if cfg.gym_id not in gym_registry: print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...") diff --git a/src/lerobot/envs/rlbench.py b/src/lerobot/envs/rlbench.py new file mode 100644 index 0000000000..2b33871a17 --- /dev/null +++ b/src/lerobot/envs/rlbench.py @@ -0,0 +1,621 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from collections import defaultdict +from collections.abc import Callable, Iterable, Mapping, Sequence +from functools import partial +from pathlib import Path +from typing import Any + +import gymnasium as gym +import numpy as np +from gymnasium import spaces +from pyrep.const import RenderMode +from pyrep.objects.dummy import Dummy +from pyrep.objects.vision_sensor import VisionSensor +from rlbench.action_modes.action_mode import MoveArmThenGripper +from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaIK +from rlbench.action_modes.gripper_action_modes import Discrete +from rlbench.backend.exceptions import BoundaryError, InvalidActionError, TaskEnvironmentError, WaypointError +from rlbench.backend.task import Task +from rlbench.environment import Environment +from rlbench.observation_config import CameraConfig, ObservationConfig +from rlbench.tasks import FS10_V1, FS25_V1, FS50_V1, FS95_V1, MT15_V1, MT30_V1, MT55_V1, MT100_V1 +from scipy.spatial.transform import Rotation + + +class DeltaEEFPositionActionMode(MoveArmThenGripper): + """Delta end-effector pose action mode for arm and gripper. + 8 values: [Δposition Δquaternion gripper]. + + The arm action is first applied, followed by the gripper action. + """ + + def __init__(self): + super().__init__( + EndEffectorPoseViaIK(absolute_mode=False), # Arm in delta end-effector pose + Discrete(), # Gripper in discrete open/close (<0.5 → close, >=0.5 → open) + ) + + def action_bounds(self): + """Returns the min and max bounds of the action mode. + Range is [-0.3, 0.3] for position, [-1, 1] for rotation (quaternion components) and [0.0, 1.0] for the gripper. + """ + return np.array([-0.3, -0.3, -0.3, -1, -1, -1, -1, 0.0]), np.array([0.3, 0.3, 0.3, 1, 1, 1, 1, 1.0]) + + +# ---- Load configuration data from the external JSON file ---- +CONFIG_PATH = Path(__file__).parent / "rlbench_config.json" +try: + with open(CONFIG_PATH) as f: + data = json.load(f) +except FileNotFoundError as err: + raise FileNotFoundError( + "Could not find 'rlbench_config.json'. " + "Please ensure the configuration file is in the same directory as the script." + ) from err +except json.JSONDecodeError as err: + raise ValueError( + "Failed to decode 'rlbench_config.json'. Please ensure it is a valid JSON file." + ) from err + +# ---- Process the loaded data ---- + +# extract top-level dicts (falling back to empty dicts if keys are missing) +TASK_DESCRIPTIONS: dict[str, str] = data.get("TASK_DESCRIPTIONS", {}) +TASK_ID_TO_NAME: dict[str, str] = data.get("TASK_ID_TO_NAME", {}) + +""" +RLBench can support many action and observation types. +Here, we define standard dimensions for end-effector position control with gripper.
+""" +ACTION_DIM = 8 # EEF pose+gripper (dim=8, [Δx Δy Δz Δqx Δqy Δqz Δqw gripper]) +OBS_DIM = 7 # EEF pose+gripper (dim=7, [x y z rx ry rz gripper]) + + +def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: + """Normalize camera_name into a non-empty list of strings.""" + if isinstance(camera_name, str): + cams = [c.strip() for c in camera_name.split(",") if c.strip()] + elif isinstance(camera_name, (list | tuple)): + cams = [str(c).strip() for c in camera_name if str(c).strip()] + else: + raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}") + if not cams: + raise ValueError("camera_name resolved to an empty list.") + return cams + + +def _get_suite(name: str) -> dict[str, list[Task]]: + """Instantiate a RLBench suite by name with clear validation.""" + + suites = { + "FS10_V1": FS10_V1, + "FS25_V1": FS25_V1, + "FS50_V1": FS50_V1, + "FS95_V1": FS95_V1, + "MT15_V1": MT15_V1, + "MT30_V1": MT30_V1, + "MT55_V1": MT55_V1, + "MT100_V1": MT100_V1, + } + + if name not in suites: + raise ValueError(f"Unknown RLBench suite '{name}'. Available: {', '.join(sorted(suites.keys()))}") + suite = suites[name] + + if not suite.get("train", None) and not suite.get("test", None): + raise ValueError(f"Suite '{name}' has no tasks.") + return suite + + +def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]: + """Validate/normalize task ids. If None → all tasks.""" + if task_ids is None: + return list(range(total_tasks)) + ids = sorted({int(t) for t in task_ids}) + for t in ids: + if t < 0 or t >= total_tasks: + raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].") + return ids + + +class RLBenchGymEnv(gym.Env): + """An gym wrapper for RLBench.""" + + metadata: dict[str, Any] = {"render_modes": ["human", "rgb_array"], "render_fps": 4} + + def __init__( + self, + task_class, + observation_mode="state", + render_mode: None | str = None, + action_mode=None, + obs_width: int = 256, + obs_height: int = 256, + ): + self.task_class = task_class + self.observation_mode = observation_mode + + if render_mode is not None and render_mode not in self.metadata["render_modes"]: + raise ValueError(f"render_mode must be one of {self.metadata['render_modes']}, got {render_mode}") + self.render_mode = render_mode + cam_config = CameraConfig(image_size=(obs_width, obs_height)) + obs_config = ObservationConfig( + left_shoulder_camera=cam_config, + right_shoulder_camera=cam_config, + wrist_camera=cam_config, + front_camera=cam_config, + overhead_camera=cam_config, + ) + if observation_mode == "state": + obs_config.set_all_high_dim(False) + obs_config.set_all_low_dim(True) + elif observation_mode == "vision": + obs_config.set_all(True) + else: + raise ValueError(f"Unrecognised observation_mode: {observation_mode}.") + self.obs_config = obs_config + if action_mode is None: + action_mode = EndEffectorPoseViaIK() + self.action_mode = action_mode + + self.rlbench_env = Environment( + action_mode=self.action_mode, + obs_config=self.obs_config, + headless=True, + ) + self.rlbench_env.launch() + self.rlbench_task_env = self.rlbench_env.get_task(self.task_class) + if render_mode is not None: + cam_placeholder = Dummy("cam_cinematic_placeholder") + self.gym_cam = VisionSensor.create([640, 360]) + self.gym_cam.set_pose(cam_placeholder.get_pose()) + if render_mode == "human": + self.gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED) + else: + self.gym_cam.set_render_mode(RenderMode.OPENGL3) + _, obs = self.rlbench_task_env.reset() + + 
gym_obs = self._extract_obs(obs) + self.observation_space = {} + for key, value in gym_obs.items(): + if "rgb" in key: + self.observation_space[key] = spaces.Box( + low=0, high=255, shape=value.shape, dtype=value.dtype + ) + else: + self.observation_space[key] = spaces.Box( + low=-np.inf, high=np.inf, shape=value.shape, dtype=value.dtype + ) + self.observation_space = spaces.Dict(self.observation_space) + + action_low, action_high = action_mode.action_bounds() + self.action_space = spaces.Box( + low=np.float32(action_low), + high=np.float32(action_high), + shape=self.rlbench_env.action_shape, + dtype=np.float32, + ) + + def _extract_obs(self, rlbench_obs): + gym_obs = {} + for state_name in [ + "joint_velocities", + "joint_positions", + "joint_forces", + "gripper_open", + "gripper_pose", + "gripper_joint_positions", + "gripper_touch_forces", + "task_low_dim_state", + ]: + state_data = getattr(rlbench_obs, state_name) + if state_data is not None: + state_data = np.float32(state_data) + if np.isscalar(state_data): + state_data = np.asarray([state_data]) + gym_obs[state_name] = state_data + + if self.observation_mode == "vision": + gym_obs.update( + { + "left_shoulder_rgb": rlbench_obs.left_shoulder_rgb, + "right_shoulder_rgb": rlbench_obs.right_shoulder_rgb, + "wrist_rgb": rlbench_obs.wrist_rgb, + "front_rgb": rlbench_obs.front_rgb, + "overhead_rgb": rlbench_obs.overhead_rgb, + } + ) + return gym_obs + + def render(self): + if self.render_mode == "rgb_array": + frame = self.gym_cam.capture_rgb() + frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8) # clip before casting so values cannot wrap + return frame + + def reset(self, seed=None, options=None): + super().reset(seed=seed) + # TODO: Remove this and use seed from super() + np.random.seed(seed=seed) + reset_to_demo = None + if options is not None: + # TODO: Write test for this + reset_to_demo = options.get("reset_to_demo", None) + + if reset_to_demo is None: + descriptions, obs = self.rlbench_task_env.reset() + else: + descriptions, obs = self.rlbench_task_env.reset(reset_to_demo=reset_to_demo) + return self._extract_obs(obs), {"text_descriptions": descriptions} + + def step(self, action): + obs, reward, terminated = self.rlbench_task_env.step(action) + return self._extract_obs(obs), reward, terminated, False, {"is_success": reward > 0.5} + + def close(self) -> None: + self.rlbench_env.shutdown() + + +class RLBenchEnv(gym.Env): + metadata = {"render_modes": ["rgb_array"], "render_fps": 30} + + def __init__( + self, + task: Task | None = None, + task_suite: dict[str, list[Task]] | None = None, + camera_name: str + | Sequence[str] = "left_shoulder_rgb,right_shoulder_rgb,front_rgb,wrist_rgb,overhead_rgb", + obs_type: str = "pixels", + render_mode: str = "rgb_array", + observation_width: int = 256, + observation_height: int = 256, + visualization_width: int = 640, + visualization_height: int = 480, + camera_name_mapping: dict[str, str] | None = None, + ): + super().__init__() + self.task = task + self.obs_type = obs_type + self.render_mode = render_mode + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height + self.camera_name = _parse_camera_names(camera_name) + + # Map raw camera names to "front_rgb", "wrist_rgb", etc. + # The preprocessing step `preprocess_observation` will then prefix these with `.images.*`, + # following the LeRobot convention (e.g., `observation.images.front_rgb`, ...).
+ # This ensures the policy consistently receives observations in the + # expected format regardless of the original camera naming. + if camera_name_mapping is None: + camera_name_mapping = { + "front_rgb": "front_rgb", + "wrist_rgb": "wrist_rgb", + "left_shoulder_rgb": "left_shoulder_rgb", + "right_shoulder_rgb": "right_shoulder_rgb", + "overhead_rgb": "overhead_rgb", + } + self.camera_name_mapping = camera_name_mapping + + self._env = self._make_envs_task(self.task) + self._max_episode_steps = 500 # TODO: make configurable depending on task suite? + + # Get task description + task_name = self.task.__name__ if self.task is not None else "" + self.task_description = TASK_DESCRIPTIONS.get(task_name, "") + + images = {} + for cam in self.camera_name: + images[self.camera_name_mapping[cam]] = spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + + if self.obs_type == "state": + self.observation_space = spaces.Dict( + { + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(OBS_DIM,), + dtype=np.float64, + ), + "agent_joints": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(7,), + dtype=np.float64, + ), + } + ) + + elif self.obs_type == "pixels": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + } + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(OBS_DIM,), + dtype=np.float64, + ), + "agent_joints": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(7,), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-5, high=5, shape=(ACTION_DIM,), dtype=np.float32) + + def render(self) -> np.ndarray: + """ + Render the current environment frame. + + Returns: + np.ndarray: The rendered RGB image from the environment. + """ + return self._env.render() + + def _make_envs_task(self, task: Task): + return RLBenchGymEnv( + task, + observation_mode="vision", + render_mode=self.render_mode, + action_mode=DeltaEEFPositionActionMode(), + obs_width=self.observation_width, + obs_height=self.observation_height, + ) + + def _format_raw_obs(self, raw_obs: dict) -> dict[str, Any]: + images = {} + for camera_name in self.camera_name: + image = raw_obs[camera_name] + image = image[::-1, ::-1] # rotate 180 degrees + images[self.camera_name_mapping[camera_name]] = image + + # Gripper pose is 7D (position + quaternion); convert the rotation to Euler angles and append the gripper open state + eef_position = np.concatenate( + [ + raw_obs["gripper_pose"][:3], + Rotation.from_quat(raw_obs["gripper_pose"][3:]).as_euler("xyz"), + raw_obs["gripper_open"], + ] + ) + robot_joints = np.array(raw_obs["joint_positions"], dtype=np.float32) + + if self.obs_type == "state": + obs = { + "agent_pos": eef_position, + "agent_joints": robot_joints, + } + elif self.obs_type == "pixels": + obs = {"pixels": images.copy()} + elif self.obs_type == "pixels_agent_pos": + obs = { + "pixels": images.copy(), + "agent_pos": eef_position, + "agent_joints": robot_joints, + } + else: + raise NotImplementedError( + f"The observation type '{self.obs_type}' is not supported in RLBench. " + "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')." + ) + return obs + + def reset( + self, + seed: int | None = None, + **kwargs, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Reset the environment to its initial state. + + Args: + seed (Optional[int]): Random seed for environment initialization.
+ + Returns: + observation (Dict[str, Any]): The initial formatted observation. + info (Dict[str, Any]): Additional info about the reset state. + """ + super().reset(seed=seed) + + raw_obs, _ = self._env.reset(seed=seed) + + observation = self._format_raw_obs(raw_obs) + + info = {"is_success": False} + return observation, info + + def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + """ + Perform one environment step. + + Args: + action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,). + + Returns: + observation (Dict[str, Any]): The formatted observation after the step. + reward (float): The scalar reward for this step. + terminated (bool): Whether the episode terminated successfully. + truncated (bool): Whether the episode was truncated due to a time limit. + info (Dict[str, Any]): Additional environment info. + """ + if action.ndim != 1: + raise ValueError( + f"Expected action to be 1-D (shape (action_dim,)), " + f"but got shape {action.shape} with ndim={action.ndim}" + ) + + # If the action has 7 elements (position + Euler angles + gripper), convert to 8 elements (position + quaternion + gripper) + if action.shape[0] == 7: + action = np.concatenate( + [ + action[:3], # position + Rotation.from_euler("xyz", action[3:6]).as_quat(), # rotation as quaternion + action[-1:], # last element is the gripper state + ] + ) + + # Force unit quaternion + quat = action[3:7] + quat = quat / np.linalg.norm(quat) + action[3:7] = quat + + # Perform step + try: + raw_obs, reward, done, truncated, info = self._env.step(action) + except (InvalidActionError, BoundaryError, WaypointError, TaskEnvironmentError) as e: + print(f"Error occurred while stepping the environment: {e}") + raw_obs, reward, done, truncated, info = None, 0.0, True, True, {} + + # Determine whether the task was successful + is_success = bool(info.get("is_success", False)) + terminated = done or is_success or truncated + info.update( + { + "task": self.task, + "done": done, + "is_success": is_success, + } + ) + + # Format the raw observation into the expected structure. If the step + # failed (raw_obs is None), reset and return the fresh observation instead. + if raw_obs is not None: + observation = self._format_raw_obs(raw_obs) + else: + observation, _ = self.reset() + if terminated: + info["final_info"] = { + "task": self.task, + "done": bool(done), + "is_success": bool(is_success), + } + self.reset() + + return observation, reward, terminated, truncated, info + + def close(self): + self._env.close() + + +def _make_env_fns( + *, + suite: dict[str, list[Task]], + task: Task, + n_envs: int, + camera_names: list[str], + gym_kwargs: Mapping[str, Any], +) -> list[Callable[[], RLBenchEnv]]: + """Build n_envs factory callables for a single (suite, task).""" + + def _make_env(**kwargs) -> RLBenchEnv: + local_kwargs = dict(kwargs) + return RLBenchEnv( + task=task, + task_suite=suite, + camera_name=camera_names, + **local_kwargs, + ) + + fns: list[Callable[[], RLBenchEnv]] = [] + for _ in range(n_envs): + fns.append(partial(_make_env, **gym_kwargs)) + return fns + + +# ---- Main API ---------------------------------------------------------------- + + +def create_rlbench_envs( + task: str, + n_envs: int, + gym_kwargs: dict[str, Any] | None = None, + camera_name: str | Sequence[str] = "front_rgb,wrist_rgb", + env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, +) -> dict[str, dict[int, Any]]: + """ + Create vectorized RLBench environments with a consistent return shape.
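+ + Example (illustrative, using a gymnasium vector class as `env_cls`): + + envs = create_rlbench_envs("FS10_V1", n_envs=2, env_cls=gym.vector.AsyncVectorEnv) + vec_env = envs["FS10_V1"][0] # 2 parallel rollouts of the suite's task id 0 +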
+ + Returns: + dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories) + Notes: + - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1). + - `task` can be a single suite or a comma-separated list of suites. + - You may pass `task_ids` (a list of ints, or a JSON-encoded string) inside `gym_kwargs` to restrict tasks per suite. + """ + + if env_cls is None or not callable(env_cls): + raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.") + if not isinstance(n_envs, int) or n_envs <= 0: + raise ValueError(f"n_envs must be a positive int; got {n_envs}.") + + gym_kwargs = dict(gym_kwargs or {}) + task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks + + camera_names = _parse_camera_names(camera_name) + suite_names = [s.strip() for s in str(task).split(",") if s.strip()] + if not suite_names: + raise ValueError("`task` must contain at least one RLBench suite name.") + + print(f"Creating RLBench envs | task_groups={suite_names} | n_envs(per task)={n_envs}") + + out: dict[str, dict[int, Any]] = defaultdict(dict) + + for suite_name in suite_names: + suite = _get_suite(suite_name) + total = len(suite["train"]) + + # Select task ids to build + if task_ids_filter is not None: + # If string, parse to list of ints (task ids) + if isinstance(task_ids_filter, str): + task_ids_filter = json.loads(task_ids_filter) + print(f"Restricting to task_ids={task_ids_filter}") + selected = _select_task_ids(total, task_ids_filter) + print( + f"Selected {len(selected)} tasks from suite '{suite_name}': {[TASK_ID_TO_NAME[str(tid)] for tid in selected]}." + ) + + if not selected: + raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).") + + for tid in selected: # FIXME: this breaks for multi-task!
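+ # Build one vectorized env per selected task id: each wraps n_envs identical + # factories, so episode_index runs 0..n_envs-1 for that task.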
+ fns = _make_env_fns( + suite=suite, + task=suite["train"][tid], + n_envs=n_envs, + camera_names=camera_names, + gym_kwargs=gym_kwargs, + ) + out[suite_name][tid] = env_cls(fns) + print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}") + + # return plain dicts for predictability + return {group: dict(task_map) for group, task_map in out.items()} diff --git a/src/lerobot/envs/rlbench_config.json b/src/lerobot/envs/rlbench_config.json new file mode 100644 index 0000000000..ee0353bb9f --- /dev/null +++ b/src/lerobot/envs/rlbench_config.json @@ -0,0 +1,219 @@ +{ + "TASK_DESCRIPTIONS": { + "BasketballInHoop": "Put the ball in the hoop", + "BeatTheBuzz": "Beat the buzz", + "BlockPyramid": "Stack blocks in a pyramid", + "ChangeChannel": "Turn the channel", + "ChangeClock": "Change the clock to show time 12.15", + "CloseBox": "Close box", + "CloseDoor": "Close the door", + "CloseDrawer": "Close drawer", + "CloseFridge": "Close fridge", + "CloseGrill": "Close the grill", + "CloseJar": "Close the jar", + "CloseLaptopLid": "Close laptop lid", + "CloseMicrowave": "Close microwave", + "EmptyContainer": "Empty the container in the to container", + "EmptyDishwasher": "Empty the dishwasher", + "GetIceFromFridge": "Get ice from fridge", + "HangFrameOnHanger": "Hang frame on hanger", + "HitBallWithQueue": "Hit ball with queue in to the goal", + "Hockey": "Hit the ball into the net", + "InsertOntoSquarePeg": "Put the ring on the spoke", + "InsertUsbInComputer": "Insert usb in computer", + "LampOff": "Turn off the light", + "LampOn": "Turn on the light", + "LiftNumberedBlock": "Pick up the block with the number", + "LightBulbIn": "Screw in the light bulb", + "LightBulbOut": "Put the bulb in the holder", + "MeatOffGrill": "Take the off the grill", + "MeatOnGrill": "Put the on the grill", + "MoveHanger": "Move hanger onto the other rack", + "OpenBox": "Open box", + "OpenDoor": "Open the door", + "OpenDrawer": "Open drawer", + "OpenFridge": "Open fridge", + "OpenGrill": "Open the grill", + "OpenJar": "Open the jar", + "OpenMicrowave": "Open microwave", + "OpenOven": "Open the oven", + "OpenWashingMachine": "Open washing machine", + "OpenWindow": "Open left window", + "OpenWineBottle": "Open wine bottle", + "PhoneOnBase": "Put the phone on the base", + "PickAndLift": "Pick up the block and lift it up to the target", + "PickAndLiftSmall": "Pick up the and lift it up to the target", + "PickUpCup": "Pick up the cup", + "PlaceCups": "Place 1 cup on the cup holder", + "PlaceHangerOnRack": "Pick up the hanger and place in on the rack", + "PlaceShapeInShapeSorter": "Put the in the shape sorter", + "PlayJenga": "Play jenga", + "PlugChargerInPowerSupply": "Plug charger in power supply", + "PourFromCupToCup": "Pour liquid from the cup to the cup", + "PressSwitch": "Press switch", + "PushButton": "Push the button", + "PushButtons": "Push the buttons", + "PutAllGroceriesInCupboard": "Put all of the groceries in the cupboard", + "PutBooksOnBookshelf": "Put books on bookshelf", + "PutBottleInFridge": "Put bottle in fridge", + "PutGroceriesInCupboard": "Put the in the cupboard", + "PutItemInDrawer": "Put item in drawer", + "PutKnifeInKnifeBlock": "Put the knife in the knife block", + "PutKnifeOnChoppingBoard": "Put the knife on the chopping board", + "PutMoneyInSafe": "Put the money away in the safe on the shelf", + "PutPlateInColoredDishRack": "Put the plate between the pillars of the dish rack", + "PutRubbishInBin": "Put rubbish in bin", + "PutShoesInBox": "Put the shoes in the box", + 
"PutToiletRollOnStand": "Put toilet roll on stand", + "PutTrayInOven": "Put tray in oven", + "PutUmbrellaInUmbrellaStand": "Put umbrella in umbrella stand", + "ReachAndDrag": "Use the stick to drag the cube onto the target", + "ReachTarget": "Reach the target", + "RemoveCups": "Remove 1 cup from the cup holder and place it on the", + "ScoopWithSpatula": "Scoop up the cube and lift it with the spatula", + "ScrewNail": "Screw the nail in to the block", + "SetTheTable": "Set the table", + "SetupCheckers": "Setup checkers", + "SetupChess": "Setup chess", + "SlideBlockToTarget": "Slide the block to target", + "SlideCabinetOpen": "Slide cabinet open", + "SlideCabinetOpenAndPlaceCups": "Put cup in cabinet", + "SolvePuzzle": "Solve the puzzle", + "StackBlocks": "Stack blocks", + "StackChairs": "Stack the other chairs on top of the chair", + "StackCups": "Stack the other cups on top of the cup", + "StackWine": "Stack wine bottle", + "StraightenRope": "Straighten rope", + "SweepToDustpan": "Sweep dirt to dustpan", + "TakeCupOutFromCabinet": "Take out a cup from the half of the cabinet", + "TakeFrameOffHanger": "Take frame off hanger", + "TakeItemOutOfDrawer": "Take item out of the drawer", + "TakeLidOffSaucepan": "Take lid off the saucepan", + "TakeMoneyOutSafe": "Take the money out of the bottom shelf and place it on", + "TakeOffWeighingScales": "Remove the pepper from the weighing scales and place it", + "TakePlateOffColoredDishRack": "Take plate off the colored rack", + "TakeShoesOutOfBox": "Take shoes out of box", + "TakeToiletRollOffStand": "Take toilet roll off stand", + "TakeTrayOutOfOven": "Take tray out of oven", + "TakeUmbrellaOutOfUmbrellaStand": "Take umbrella out of umbrella stand", + "TakeUsbOutOfComputer": "Take usb out of computer", + "ToiletSeatDown": "Toilet seat down", + "ToiletSeatUp": "Lift toilet seat up", + "TurnOvenOn": "Turn on the oven", + "TurnTap": "Turn tap", + "TvOn": "Turn on the TV", + "UnplugCharger": "Unplug charger", + "WaterPlants": "Water plant", + "WeighingScales": "Weigh the pepper", + "WipeDesk": "Wipe dirt off the desk" + }, + + "TASK_ID_TO_NAME": { + "0": "ReachTarget", + "1": "CloseBox", + "2": "CloseMicrowave", + "3": "PlugChargerInPowerSupply", + "4": "ToiletSeatDown", + "5": "TakeUmbrellaOutOfUmbrellaStand", + "6": "PutUmbrellaInUmbrellaStand", + "7": "SlideCabinetOpen", + "8": "CloseFridge", + "9": "PickAndLift", + "10": "OpenBox", + "11": "OpenMicrowave", + "12": "UnplugCharger", + "13": "ToiletSeatUp", + "14": "OpenFridge", + "15": "TurnTap", + "16": "LightBulbIn", + "17": "BasketballInHoop", + "18": "OpenWindow", + "19": "CloseDoor", + "20": "PushButton", + "21": "PutItemInDrawer", + "22": "OpenDrawer", + "23": "CloseDrawer", + "24": "TurnOvenOn", + "25": "LightBulbOut", + "26": "TvOn", + "27": "OpenOven", + "28": "OpenDoor", + "29": "TakeItemOutOfDrawer", + "30": "BeatTheBuzz", + "31": "BlockPyramid", + "32": "ChangeClock", + "33": "CloseJar", + "34": "CloseLaptopLid", + "35": "EmptyContainer", + "36": "EmptyDishwasher", + "37": "GetIceFromFridge", + "38": "HangFrameOnHanger", + "39": "InsertOntoSquarePeg", + "40": "PutRubbishInBin", + "41": "PutShoesInBox", + "42": "PutToiletRollOnStand", + "43": "PutTrayInOven", + "44": "ReachAndDrag", + "45": "RemoveCups", + "46": "ScoopWithSpatula", + "47": "SetTheTable", + "48": "SetupCheckers", + "49": "SlideBlockToTarget", + "50": "Hockey", + "51": "InsertUsbInComputer", + "52": "PressSwitch", + "53": "PlayJenga", + "54": "MeatOffGrill", + "55": "HitBallWithQueue", + "56": "ScrewNail", + "57": "LampOff", + 
"58": "LampOn", + "59": "MeatOnGrill", + "60": "MoveHanger", + "61": "OpenJar", + "62": "OpenWineBottle", + "63": "PlaceCups", + "64": "PlaceHangerOnRack", + "65": "PlaceShapeInShapeSorter", + "66": "PutBottleInFridge", + "67": "PutKnifeInKnifeBlock", + "68": "PutMoneyInSafe", + "69": "PutPlateInColoredDishRack", + "70": "SlideCabinetOpenAndPlaceCups", + "71": "StackBlocks", + "72": "StackCups", + "73": "StackWine", + "74": "StraightenRope", + "75": "SweepToDustpan", + "76": "TakeCupOutFromCabinet", + "77": "TakeFrameOffHanger", + "78": "TakeLidOffSaucepan", + "79": "TakeMoneyOutSafe", + "80": "TakeOffWeighingScales", + "81": "TakePlateOffColoredDishRack", + "82": "TakeShoesOutOfBox", + "83": "TakeToiletRollOffStand", + "84": "TakeUsbOutOfComputer", + "85": "WaterPlants", + "86": "WeighingScales", + "87": "WipeDesk", + "88": "ChangeChannel", + "89": "OpenGrill", + "90": "CloseGrill", + "91": "SolvePuzzle", + "92": "PickUpCup", + "93": "PhoneOnBase", + "94": "PourFromCupToCup", + "95": "PutKnifeOnChoppingBoard", + "96": "PutBooksOnBookshelf", + "97": "PushButtons", + "98": "PutGroceriesInCupboard", + "99": "TakeTrayOutOfOven", + "100": "LiftNumberedBlock", + "101": "OpenWashingMachine", + "102": "PickAndLiftSmall", + "103": "PutAllGroceriesInCupboard", + "104": "SetupChess", + "105": "StackChairs" + } +} diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 5584e0bff0..1c3c54871c 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -26,7 +26,7 @@ from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig -from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STATE_JOINTS from lerobot.utils.utils import get_channel_first_image_shape @@ -75,6 +75,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return_observations[OBS_ENV_STATE] = env_state + if "agent_joints" in observations: + joints_state = torch.from_numpy(observations["agent_joints"]).float() + if joints_state.dim() == 1: + joints_state = joints_state.unsqueeze(0) + + return_observations[OBS_STATE_JOINTS] = joints_state + # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing agent_pos = torch.from_numpy(observations["agent_pos"]).float() if agent_pos.dim() == 1: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 6847666eb9..33e08aa0f6 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -21,6 +21,7 @@ OBS_PREFIX = OBS_STR + "." OBS_ENV_STATE = OBS_STR + ".environment_state" OBS_STATE = OBS_STR + ".state" +OBS_STATE_JOINTS = OBS_STATE + ".joints" OBS_IMAGE = OBS_STR + ".image" OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language"