diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7d819218..4cb35e58 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -80,6 +80,7 @@ jobs: export HF_ENDPOINT=https://hf-mirror.com pip uninstall pymeshlab -y pip install pymeshlab==2023.12.post3 + pip install numpy==1.26.4 pytest tests publish: diff --git a/configs/gym/pour_water/gym_config.json b/configs/gym/pour_water/gym_config.json index 9048f817..04c73b1b 100644 --- a/configs/gym/pour_water/gym_config.json +++ b/configs/gym/pour_water/gym_config.json @@ -1,6 +1,6 @@ { "id": "PourWater-v3", - "max_episodes": 5, + "max_episodes": 10, "env": { "events": { "random_light": { @@ -258,11 +258,38 @@ } }, "dataset": { - "robot_meta": { - "arm_dofs": 12, - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], - "min_len_steps": 5 + "lerobot": { + "func": "LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "robot_type": "CobotMagic", + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": { + "cam_high": ["mask"], + "cam_right_wrist": ["mask"], + "cam_left_wrist": ["mask"] + }, + "states": ["qpos"], + "exteroception": ["cam_high", "cam_right_wrist", "cam_left_wrist"] + }, + "action": "qpos_with_eef_pose", + "min_len_steps": 5 + }, + "instruction": { + "lang": "Pour water from bottle to cup" + }, + "extra": { + "scene_type": "Commercial", + "task_description": "Pour water", + "data_type": "sim" + }, + "use_videos": true, + "export_success_only": false + } } } }, diff --git a/configs/gym/pour_water/gym_config_simple.json b/configs/gym/pour_water/gym_config_simple.json new file mode 100644 index 00000000..f116e0f9 --- /dev/null +++ b/configs/gym/pour_water/gym_config_simple.json @@ -0,0 +1,326 @@ +{ + "id": "PourWater-v3", + "max_episodes": 5, + "env": { + "events": { + "random_light": { + "func": "randomize_light", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "light_1"}, + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [50.0, 100.0] + } + }, + "init_bottle_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "bottle"}, + "position_range": [[-0.08, -0.12, 0.0], [0.08, 0.04, 0.0]], + "relative_position": true + } + }, + "init_cup_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "cup"}, + "position_range": [[-0.08, -0.04, 0.0], [0.08, 0.12, 0.0]], + "relative_position": true + } + }, + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": true, + "sample_points": 5000 + } + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_cfg": { + "uid": "bottle" + }, + "value": [[ + [0.32243, 0.03245, 0.94604, 0.025], + [0.00706, -0.99947, 0.03188, -0.0 ], + [0.94657, -0.0036 , -0.32249, 0.0 ], + [0.0 , 0.0 , 0.0 , 1.0 ] + ]] + }, + { + "name": "left_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "left_arm_base", + "to_matrix": true + } + }, + { + "name": "right_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "right_arm_base", + "to_matrix": true + } + } + ] + } + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + { + "entity_cfg": { + "uid": "bottle" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "cup" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["left_arm"] + }, + "attrs": ["left_arm_base_pose"], + "pose_register_params": { + "compute_relative": "cup", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["right_arm"] + }, + "attrs": ["right_arm_base_pose"], + "pose_register_params": { + "compute_relative": "bottle", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + } + ], + "registration": "affordance_datas", + "sim_update": true + } + }, + "random_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 0.0, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_cup_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "cup"}, + "random_texture_prob": 0.0, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_bottle_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "bottle"}, + "random_texture_prob": 0.0, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_robot_init_eef_pose": { + "func": "randomize_robot_eef_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "CobotMagic", "control_parts": ["left_arm", "right_arm"]}, + "position_range": [[-0.01, -0.01, -0.01], [0.01, 0.01, 0]] + } + } + }, + "observations": { + "norm_robot_eef_joint": { + "func": "normalize_robot_joint_data", + "mode": "modify", + "name": "robot/qpos", + "params": { + "joint_ids": [12, 13, 14, 15] + } + } + }, + "dataset": { + "lerobot": { + "func": "LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "robot_type": "CobotMagic", + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": { + "cam_high": ["mask"] + }, + "states": ["qpos"], + "exteroception": ["cam_high", "cam_right_wrist", "cam_left_wrist"] + }, + "action": "qpos_with_eef_pose", + "min_len_steps": 5 + }, + "instruction": { + "lang": "Pour water from bottle to cup" + }, + "extra": { + "scene_type": "Commercial", + "task_description": "Pour water", + "data_type": "sim" + }, + "use_videos": true, + "export_success_only": false + } + } + } + }, + "robot": { + "uid": "CobotMagic", + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [-0.3,0.3,1.0,1.0,-1.2,-1.2,0.0,0.0,0.6,0.6,0.0,0.0,0.05,0.05,0.05,0.05] + }, + "sensor": [ + { + "sensor_type": "Camera", + "uid": "cam_high", + "width": 960, + "height": 540, + "intrinsics": [488.1665344238281, 488.1665344238281, 480, 270], + "extrinsics": { + "eye": [0.35368482807598, 0.014695524383058989, 1.4517046071614774], + "target": [0.8586357573287919, 0, 0.5232553674540066], + "up": [0.9306678549330372, -0.0005600064212467153, 0.3658647703553347] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 50.0, + "init_pos": [2, 0, 2], + "radius": 10.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.725, 0.0, 0.825], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid":"cup", + "shape": { + "shape_type": "Mesh", + "fpath": "PaperCup/paper_cup.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, 0.1, 0.9], + "body_scale":[0.75, 0.75, 1.0], + "max_convex_hull_num": 8 + }, + { + "uid":"bottle", + "shape": { + "shape_type": "Mesh", + "fpath": "ScannedBottle/kashijia_processed.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, -0.1, 0.932], + "body_scale":[1, 1, 1], + "max_convex_hull_num": 8 + } + ] +} \ No newline at end of file diff --git a/configs/gym/special/simple_task_ur10.json b/configs/gym/special/simple_task_ur10.json new file mode 100644 index 00000000..ee84c5ff --- /dev/null +++ b/configs/gym/special/simple_task_ur10.json @@ -0,0 +1,114 @@ +{ + "id": "SimpleTask-v1", + "max_episodes": 24, + "env": { + "events": { + "random_light": { + "func": "randomize_light", + "mode": "interval", + "interval_step": 20, + "params": { + "entity_cfg": {"uid": "light_1"}, + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [50.0, 100.0] + } + }, + "random_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 50, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 0.0, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + } + }, + "dataset": { + "lerobot": { + "func": "LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "robot_type": "UR10", + "arm_dofs": 6, + "control_freq": 3, + "observation": { + "vision": { + "cam_high": [] + }, + "states": ["qpos"] + }, + "action": "qpos", + "min_len_steps": 5 + }, + "instruction": { + "lang": "Acting with Oscillatory motion" + }, + "extra": { + "scene_type": "commercial", + "task_description": "Oscillatory motion", + "data_type": "sim" + }, + "use_videos": false, + "export_success_only": false + } + } + } + }, + "robot": { + "uid": "UR10", + "fpath": "UniversalRobots/UR10/UR10.urdf", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [1.57079, -1.57079, 1.57079, -1.57079, -1.57079, -3.14159] + }, + "sensor": [ + { + "sensor_type": "Camera", + "uid": "cam_high", + "width": 640, + "height": 480, + "intrinsics": [488.1665344238281, 488.1665344238281, 320.0, 240.0], + "extrinsics": { + "eye": [1, 0, 3], + "target": [0, 0, 1] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 50.0, + "init_pos": [2, 0, 2], + "radius": 10.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.8, 0.0, 0.825], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + ] +} \ No newline at end of file diff --git a/docs/source/quick_start/install.md b/docs/source/quick_start/install.md index fd00abef..7bb9bf5f 100644 --- a/docs/source/quick_start/install.md +++ b/docs/source/quick_start/install.md @@ -53,8 +53,14 @@ Install the project in development mode: ```bash pip install -e . --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site + +# Or install with the lerobot extras: +pip install -e .[lerobot] --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site ``` +> [!NOTE] +> * [LeRobot](https://huggingface.co/docs/lerobot/installation) is an optional module for EmbodiChain that provides data saving and loading functionalities for robot learning tasks. Installing with the `lerobot` extras will include this module and its dependencies. + ### Verify Installation To verify that EmbodiChain is installed correctly, run a simple demo script to create a simulation scene: diff --git a/embodichain/data/constants.py b/embodichain/data/constants.py index 17264a80..8409d89a 100644 --- a/embodichain/data/constants.py +++ b/embodichain/data/constants.py @@ -21,3 +21,4 @@ "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/" ) EMBODICHAIN_DEFAULT_DATA_ROOT = str(Path.home() / ".cache" / "embodichain_data") +EMBODICHAIN_DEFAULT_DATASET_ROOT = str(Path.home() / ".cache" / "embodichain_datasets") diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py index 3716b840..cf758192 100644 --- a/embodichain/data/enum.py +++ b/embodichain/data/enum.py @@ -15,6 +15,8 @@ # ---------------------------------------------------------------------------- from enum import Enum, IntEnum +import torch +import numpy as np class SemanticMask(IntEnum): @@ -59,3 +61,16 @@ class Hints(Enum): EndEffector.DEXTROUSHAND.value, ) ARM = (ControlParts.LEFT_ARM.value, ControlParts.RIGHT_ARM.value) + + +class JointType(Enum): + QPOS = "qpos" + + +class EefType(Enum): + POSE = "eef_pose" + + +class ActionMode(Enum): + ABSOLUTE = "" + RELATIVE = "delta_" # This indicates the action is relative change with respect to last state. diff --git a/embodichain/lab/gym/envs/__init__.py b/embodichain/lab/gym/envs/__init__.py index 286d7824..88257690 100644 --- a/embodichain/lab/gym/envs/__init__.py +++ b/embodichain/lab/gym/envs/__init__.py @@ -29,3 +29,5 @@ # Reinforcement learning environments from embodichain.lab.gym.envs.tasks.rl.push_cube import PushCubeEnv + +from embodichain.lab.gym.envs.tasks.special.simple_task import SimpleTaskEnv diff --git a/embodichain/lab/gym/envs/action_bank/configurable_action.py b/embodichain/lab/gym/envs/action_bank/configurable_action.py index 62144cfc..3dd649e9 100644 --- a/embodichain/lab/gym/envs/action_bank/configurable_action.py +++ b/embodichain/lab/gym/envs/action_bank/configurable_action.py @@ -616,9 +616,16 @@ def initialize_with_current_qpos( # TODO: Hard to get current qpos for multi-agent env current_qpos = env.robot.get_qpos() joint_ids = env.robot.get_joint_ids(name=get_control_part(env, executor)) - if current_qpos.ndim == 2 and current_qpos.shape[0] == 1: - current_qpos = current_qpos[0] - current_qpos = current_qpos[joint_ids].cpu() + + # Handle multi-environment case + if current_qpos.ndim == 2: + # current_qpos shape: [num_envs, num_joints] + # Take first environment and then select joints + current_qpos = current_qpos[0, joint_ids].cpu() + else: + # Single environment case + # current_qpos shape: [num_joints] + current_qpos = current_qpos[joint_ids].cpu() executor_qpos_dim = action_list[executor].shape[0] diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 1b2e5b5a..15f5d3be 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -103,13 +103,13 @@ def __init__( self.cfg = cfg # the number of envs to be simulated in parallel. - self.num_envs = self.cfg.num_envs + self._num_envs = self.cfg.num_envs if self.cfg.sim_cfg is None: self.sim_cfg = SimulationManagerCfg(headless=True) else: self.sim_cfg = self.cfg.sim_cfg - self.sim_cfg.num_envs = self.num_envs + self.sim_cfg.num_envs = self._num_envs if self.cfg.seed is not None: self.cfg.seed = set_seed(self.cfg.seed) @@ -129,7 +129,7 @@ def __init__( self.sim.open_window() self._elapsed_steps = torch.zeros( - self.num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device + self._num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device ) self._init_sim_state(**kwargs) @@ -138,7 +138,7 @@ def __init__( logger.log_info("[INFO]: Initialized environment:") logger.log_info(f"\tEnvironment device : {self.sim.device}") - logger.log_info(f"\tNumber of environments: {self.num_envs}") + logger.log_info(f"\tNumber of environments: {self._num_envs}") logger.log_info(f"\tEnvironment seed : {self.cfg.seed}") logger.log_info(f"\tPhysics dt : {self.sim_cfg.physics_dt}") logger.log_info( @@ -146,7 +146,12 @@ def __init__( ) @property - def device(self) -> torch.Tensor: + def num_envs(self) -> int: + """Return the number of environments simulated in parallel.""" + return self._num_envs + + @property + def device(self) -> torch.device: """Return the device used by the environment.""" return self.sim.device @@ -380,7 +385,7 @@ def get_info(self, **kwargs) -> Dict[str, Any]: info.update(self.evaluate(**kwargs)) return info - def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> bool: + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: """Check if the episode is truncated. Args: @@ -388,7 +393,7 @@ def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> bool: info: The info dictionary. Returns: - True if the episode is truncated, False otherwise. + A boolean tensor indicating truncation for each environment in the batch. """ return torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index cd5ac648..8ce31346 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -20,7 +20,7 @@ import gymnasium as gym from dataclasses import MISSING -from typing import Dict, Union, Sequence, Tuple, Any, List +from typing import Dict, Union, Sequence, Tuple, Any, List, Optional from embodichain.lab.sim.cfg import ( RobotCfg, @@ -42,6 +42,7 @@ from embodichain.lab.gym.envs.managers import ( EventManager, ObservationManager, + DatasetManager, ) from embodichain.lab.gym.utils.registration import register_env from embodichain.utils import configclass, logger @@ -90,9 +91,10 @@ class EnvLightCfg: Please refer to the :class:`embodichain.lab.gym.managers.ObservationManager` class for more details. """ - # TODO: This would be changed to a more generic data pipeline configuration. - dataset: Union[Dict[str, Any], None] = None - """Data pipeline configuration. Defaults to None. + dataset: Union[object, None] = None + """Dataset settings. Defaults to None, in which case no dataset collection is performed. + + Please refer to the :class:`embodichain.lab.gym.managers.DatasetManager` class for more details. """ extensions: Union[Dict[str, Any], None] = None @@ -145,6 +147,10 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.affordance_datas = {} self.action_bank = None + # TODO: Change to array like data structure to handle different demo action list length for across different arena. + self.action_length: int = 0 # Set by create_demo_action_list + self._action_step_counter: int = 0 # Track steps within current action sequence + extensions = getattr(cfg, "extensions", {}) or {} for name, value in extensions.items(): @@ -170,14 +176,8 @@ def _init_sim_state(self, **kwargs): if self.cfg.observations: self.observation_manager = ObservationManager(self.cfg.observations, self) - # TODO: A workaround for handling dataset saving, which need history data of obs-action pairs. - # We may improve this by implementing a data manager to handle data saving and online streaming. - if self.cfg.dataset is not None: - self.metadata["dataset"] = self.cfg.dataset - self.episode_obs_list = [] - self.episode_action_list = [] - - self.curr_episode = 0 + if self.cfg.dataset: + self.dataset_manager = DatasetManager(self.cfg.dataset, self) def _apply_functor_filter(self) -> None: """Apply functor filters to the environment components based on configuration. @@ -261,25 +261,68 @@ def reset( self, seed: int | None = None, options: dict | None = None ) -> Tuple[EnvObs, Dict]: obs, info = super().reset(seed=seed, options=options) - - if hasattr(self, "episode_obs_list"): - self.episode_obs_list = [obs] - self.episode_action_list = [] - + self._action_step_counter = 0 # Reset action step counter return obs, info def step( self, action: EnvAction, **kwargs ) -> Tuple[EnvObs, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: - # TODO: Maybe add action preprocessing manager and its functors. - obs, reward, done, truncated, info = super().step(action, **kwargs) - - if hasattr(self, "episode_action_list"): + """Step the environment with the given action. + + Extends BaseEnv.step() to integrate with DatasetManager for automatic + data collection and saving. The key is to: + 1. Record obs-action pairs as they happen + 2. Detect episode completion + 3. Auto-save episodes BEFORE reset + 4. Then perform the actual reset + """ + self._elapsed_steps += 1 + self._action_step_counter += 1 # Increment action sequence counter + + action = self._step_action(action=action) + self.sim.update(self.sim_cfg.physics_dt, self.cfg.sim_steps_per_control) + self._update_sim_state(**kwargs) + + obs = self.get_obs(**kwargs) + info = self.get_info(**kwargs) + rewards = self.get_reward(obs=obs, action=action, info=info) + + # Check termination conditions + terminateds = torch.logical_or( + info.get( + "success", + torch.zeros(self.num_envs, dtype=torch.bool, device=self.device), + ), + info.get( + "fail", torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + ), + ) + truncateds = self.check_truncated(obs=obs, info=info) + if self.cfg.ignore_terminations: + terminateds[:] = False + + # Detect which environments need reset + dones = torch.logical_or(terminateds, truncateds) + reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) + + # Call dataset manager with mode="save": it will record and auto-save if dones=True + if self.cfg.dataset: + if "save" in self.dataset_manager.available_modes: + self.dataset_manager.apply( + mode="save", + env_ids=None, + obs=obs, + action=action, + dones=dones, + terminateds=terminateds, + info=info, + ) - self.episode_obs_list.append(obs) - self.episode_action_list.append(action) + # Now perform reset for completed environments + if len(reset_env_ids) > 0: + obs, _ = self.reset(options={"reset_ids": reset_env_ids}) - return obs, reward, done, truncated, info + return obs, rewards, terminateds, truncateds, info def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: if self.observation_manager: @@ -450,6 +493,14 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None This function should be implemented in subclasses to generate a sequence of actions that demonstrate a specific task or behavior within the environment. + Important: + Subclasses MUST set `self.action_length` to the length of the returned action list. + This is used by the environment to automatically detect episode truncation. + Example: + action_list = [...] # Generate actions + self.action_length = len(action_list) + return action_list + Returns: Sequence[EnvAction] | None: A list of actions if a demonstration is available, otherwise None. """ @@ -457,29 +508,40 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None "The method 'create_demo_action_list' must be implemented in subclasses." ) - def to_dataset(self, id: str, save_path: str = None) -> str | None: - """Convert the recorded episode data to a dataset format. + def is_task_success(self, **kwargs) -> torch.Tensor: + """ + Determine if the task is successfully completed. This is mainly used in the data generation process + of the imitation learning. Args: - id (str): Unique identifier for the dataset. - save_path (str, optional): Path to save the dataset. If None, use config or default. + **kwargs: Additional arguments for task-specific success criteria. Returns: - str | None: The path to the saved dataset, or None if failed. + torch.Tensor: A boolean tensor indicating success for each environment in the batch. """ - raise NotImplementedError( - "The method 'to_dataset' will be implemented in the near future." - ) - def is_task_success(self, **kwargs) -> torch.Tensor: - """Determine if the task is successfully completed. This is mainly used in the data generation process - of the imitation learning. + return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: + """Check if the episode is truncated. Args: - **kwargs: Additional arguments for task-specific success criteria. + obs: The observation from the environment. + info: The info dictionary. Returns: - torch.Tensor: A boolean tensor indicating success for each environment in the batch. + A boolean tensor indicating truncation for each environment in the batch. """ + # Check if action sequence has reached its end + if self.action_length > 0 and self._action_step_counter >= self.action_length: + return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) - return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + return super().check_truncated(obs, info) + + def close(self) -> None: + """Close the environment and release resources.""" + # Finalize dataset if present + if self.cfg.dataset: + self.dataset_manager.finalize() + + self.sim.destroy() diff --git a/embodichain/lab/gym/envs/managers/__init__.py b/embodichain/lab/gym/envs/managers/__init__.py index 946165a8..e38f4f22 100644 --- a/embodichain/lab/gym/envs/managers/__init__.py +++ b/embodichain/lab/gym/envs/managers/__init__.py @@ -14,7 +14,15 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .cfg import FunctorCfg, SceneEntityCfg, EventCfg, ObservationCfg +from .cfg import ( + FunctorCfg, + SceneEntityCfg, + EventCfg, + ObservationCfg, + DatasetFunctorCfg, +) from .manager_base import Functor, ManagerBase from .event_manager import EventManager from .observation_manager import ObservationManager +from .dataset_manager import DatasetManager +from .datasets import * diff --git a/embodichain/lab/gym/envs/managers/cfg.py b/embodichain/lab/gym/envs/managers/cfg.py index 3f5c8da6..07888c9f 100644 --- a/embodichain/lab/gym/envs/managers/cfg.py +++ b/embodichain/lab/gym/envs/managers/cfg.py @@ -309,3 +309,15 @@ def _resolve_body_names(self, scene: SimulationManager): if isinstance(self.body_ids, int): self.body_ids = [self.body_ids] self.body_names = [entity.body_names[i] for i in self.body_ids] + + +@configclass +class DatasetFunctorCfg(FunctorCfg): + """Configuration for dataset collection functors. + + Dataset functors are called with mode="save" which handles both: + - Recording observation-action pairs on every step + - Auto-saving episodes when dones=True + """ + + mode: Literal["save"] = "save" diff --git a/embodichain/lab/gym/envs/managers/dataset_manager.py b/embodichain/lab/gym/envs/managers/dataset_manager.py new file mode 100644 index 00000000..0a8e9d49 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/dataset_manager.py @@ -0,0 +1,321 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Dataset manager for orchestrating dataset collection functors.""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from collections.abc import Sequence + +import torch +from prettytable import PrettyTable + +from embodichain.utils import logger +from embodichain.lab.sim.types import EnvObs, EnvAction +from .manager_base import ManagerBase +from .cfg import DatasetFunctorCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class DatasetManager(ManagerBase): + """Manager for orchestrating dataset collection and saving using functors. + + The dataset manager supports multiple dataset formats through a functor system: + - LeRobot format (via LeRobotRecorder) + - HDF5 format (via HDF5Recorder) + - Zarr format (via ZarrRecorder) + - Custom formats (via user-defined functors) + + Each functor's step() method is called once per environment step and handles: + - Recording observation-action pairs + - Detecting episode completion (dones=True) + - Auto-saving completed episodes + + Example configuration: + >>> from embodichain.lab.gym.envs.managers.cfg import DatasetFunctorCfg + >>> from embodichain.lab.gym.envs.managers.datasets import LeRobotRecorder + >>> + >>> @configclass + >>> class MyEnvCfg: + >>> dataset: dict = { + >>> "lerobot": DatasetFunctorCfg( + >>> func=LeRobotRecorder, + >>> params={ + >>> "robot_meta": {...}, + >>> "instruction": {"lang": "pick and place"}, + >>> "extra": {"scene_type": "kitchen"}, + >>> "save_path": "/data/datasets", + >>> "export_success_only": True, + >>> } + >>> ) + >>> } + """ + + _env: EmbodiedEnv + """The environment instance.""" + + def __init__(self, cfg: object, env: EmbodiedEnv): + """Initialize the dataset manager. + + Args: + cfg: Configuration object containing dataset functor configurations. + env: The environment instance. + """ + # Store functors by mode (similar to EventManager) + self._mode_functor_names: dict[str, list[str]] = {} + self._mode_functor_cfgs: dict[str, list[DatasetFunctorCfg]] = {} + self._mode_class_functor_cfgs: dict[str, list[DatasetFunctorCfg]] = {} + + # Call base class to parse functors + super().__init__(cfg, env) + + ## TODO: fix configurable_action.py to avoid getting env.metadata['dataset'] + # Extract robot_meta from first functor and add to env.metadata for backward compatibility + # This allows legacy code (like action_bank) to access robot_meta via env.metadata["dataset"]["robot_meta"] + for mode_cfgs in self._mode_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + if "robot_meta" in functor_cfg.params: + if not hasattr(env, "metadata"): + env.metadata = {} + if "dataset" not in env.metadata: + env.metadata["dataset"] = {} + env.metadata["dataset"]["robot_meta"] = functor_cfg.params[ + "robot_meta" + ] + logger.log_info( + "Added robot_meta to env.metadata for backward compatibility" + ) + break + else: + continue + break + + logger.log_info( + f"DatasetManager initialized with {sum(len(v) for v in self._mode_functor_names.values())} functors" + ) + + def __str__(self) -> str: + """Returns: A string representation for dataset manager.""" + msg = f" contains {len(self._functor_names)} active functors.\n" + + table = PrettyTable() + table.title = "Active Dataset Functors" + table.field_names = ["Index", "Name", "Type"] + table.align["Name"] = "l" + + for index, name in enumerate(self._functor_names): + functor_cfg = self._functor_cfgs[index] + functor_type = ( + functor_cfg.func.__class__.__name__ + if hasattr(functor_cfg.func, "__class__") + else str(functor_cfg.func) + ) + table.add_row([index, name, functor_type]) + + msg += table.get_string() + msg += "\n" + + return msg + + """ + Properties. + """ + + @property + def active_functors(self) -> dict[str, list[str]]: + """Name of active dataset functors by mode. + + The keys are the modes and the values are the names of the dataset functors. + """ + return self._mode_functor_names + + @property + def available_modes(self) -> list[str]: + """List of available modes for the dataset manager.""" + return list(self._mode_functor_names.keys()) + + """ + Operations. + """ + + def reset( + self, env_ids: Union[Sequence[int], torch.Tensor, None] = None + ) -> dict[str, float]: + """Reset all dataset functors. + + Args: + env_ids: The environment ids. Defaults to None. + + Returns: + Empty dict (no logging info). + """ + # Call reset on all class functors across all modes + for mode_cfgs in self._mode_class_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + functor_cfg.func.reset(env_ids=env_ids) + + return {} + + def apply( + self, + mode: str, + env_ids: Union[Sequence[int], torch.Tensor, None] = None, + obs: Optional[EnvObs] = None, + action: Optional[EnvAction] = None, + dones: Optional[torch.Tensor] = None, + terminateds: Optional[torch.Tensor] = None, + info: Optional[Dict[str, Any]] = None, + ) -> None: + """Apply dataset functors for the specified mode. + + This method follows the same pattern as EventManager.apply() for consistency. + Currently only supports mode="save" which handles both recording and auto-saving. + + Args: + mode: The mode to apply (currently only "save" is supported). + env_ids: The indices of the environments to apply the functor to. + Defaults to None, in which case the functor is applied to all environments. + obs: Observation from the environment (batched for all envs). + action: Action applied to the environment (batched for all envs). + dones: Boolean tensor indicating which envs completed episodes. + terminateds: Boolean tensor indicating termination (success/fail). + info: Info dict containing success/fail information. + """ + # check if mode is valid + if mode not in self._mode_functor_names: + logger.log_warning( + f"Dataset mode '{mode}' is not defined. Skipping dataset operation." + ) + return + + # iterate over all the dataset functors for this mode + for functor_cfg in self._mode_functor_cfgs[mode]: + functor_cfg.func( + self._env, + env_ids, + obs, + action, + dones, + terminateds, + info, + **functor_cfg.params, + ) + + def finalize(self) -> Optional[str]: + """Finalize all dataset functors. + + Called when the environment is closed. Saves any remaining episodes + and finalizes all datasets. + + Returns: + Path to the first saved dataset, or None if failed. + """ + dataset_paths = [] + + # Call finalize on all class functors across all modes + for mode_cfgs in self._mode_class_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + if hasattr(functor_cfg.func, "finalize"): + try: + path = functor_cfg.func.finalize() + if path: + dataset_paths.append(path) + except Exception as e: + logger.log_error(f"Failed to finalize functor: {e}") + + if dataset_paths: + logger.log_info(f"Finalized {len(dataset_paths)} datasets") + return dataset_paths[0] + + return None + + """ + Operations - Functor settings. + """ + + def get_functor_cfg(self, functor_name: str) -> DatasetFunctorCfg: + """Gets the configuration for the specified functor. + + Args: + functor_name: The name of the dataset functor. + + Returns: + The configuration of the dataset functor. + + Raises: + ValueError: If the functor name is not found. + """ + for mode, functors in self._mode_functor_names.items(): + if functor_name in functors: + return self._mode_functor_cfgs[mode][functors.index(functor_name)] + logger.log_error(f"Dataset functor '{functor_name}' not found.") + + """ + Helper functions. + """ + + def _prepare_functors(self): + """Prepare dataset functors from configuration. + + This method parses the configuration and initializes all dataset functors, + organizing them by mode (similar to EventManager). + """ + # Check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + + # Iterate over all the functors + for functor_name, functor_cfg in cfg_items: + # Check for non config + if functor_cfg is None: + continue + + # Convert dict to DatasetFunctorCfg if needed (for JSON configs) + if isinstance(functor_cfg, dict): + functor_cfg = DatasetFunctorCfg(**functor_cfg) + + # Check for valid config type + if not isinstance(functor_cfg, DatasetFunctorCfg): + raise TypeError( + f"Configuration for '{functor_name}' is not of type DatasetFunctorCfg." + f" Received: '{type(functor_cfg)}'." + ) + + # Resolve common parameters + # min_argc=7 to skip: env, env_ids, obs, action, dones, terminateds, info + # These are runtime positional arguments, not config parameters + self._resolve_common_functor_cfg(functor_name, functor_cfg, min_argc=7) + + # Check if mode is a new mode + if functor_cfg.mode not in self._mode_functor_names: + # add new mode + self._mode_functor_names[functor_cfg.mode] = [] + self._mode_functor_cfgs[functor_cfg.mode] = [] + self._mode_class_functor_cfgs[functor_cfg.mode] = [] + + # Add functor name and parameters + self._mode_functor_names[functor_cfg.mode].append(functor_name) + self._mode_functor_cfgs[functor_cfg.mode].append(functor_cfg) + + # Check if the functor is a class + if inspect.isclass(functor_cfg.func): + self._mode_class_functor_cfgs[functor_cfg.mode].append(functor_cfg) diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py new file mode 100644 index 00000000..09c0c1e2 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -0,0 +1,432 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Dataset functors for collecting and saving episode data.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import numpy as np +import torch + +from embodichain.utils import logger +from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATASET_ROOT +from embodichain.lab.sim.types import EnvObs, EnvAction +from embodichain.lab.gym.utils.misc import is_stereocam +from embodichain.utils.utility import get_right_name +from embodichain.data.enum import JointType +from .manager_base import Functor +from .cfg import DatasetFunctorCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + +try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset, HF_LEROBOT_HOME + + LEROBOT_AVAILABLE = True + + __all__ = ["LeRobotRecorder"] +except ImportError: + LEROBOT_AVAILABLE = False + + __all__ = [] + + +class LeRobotRecorder(Functor): + """Functor for recording episodes in LeRobot format. + + This functor handles: + - Recording observation-action pairs during episodes + - Converting data to LeRobot format + - Saving episodes when they complete + """ + + def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv): + """Initialize the LeRobot dataset recorder. + + Args: + cfg: Functor configuration containing params: + - save_path: Root directory for saving datasets + - robot_meta: Robot metadata for dataset + - instruction: Optional task instruction + - extra: Optional extra metadata + - use_videos: Whether to save videos + - image_writer_threads: Number of threads for image writing + - image_writer_processes: Number of processes for image writing + - export_success_only: Whether to export only successful episodes + env: The environment instance + """ + super().__init__(cfg, env) + + # Extract parameters from cfg.params + params = cfg.params + + # Required parameters + self.lerobot_data_root = params.get( + "save_path", EMBODICHAIN_DEFAULT_DATASET_ROOT + ) + self.robot_meta = params.get("robot_meta", {}) + + # Optional parameters + self.instruction = params.get("instruction", None) + self.extra = params.get("extra", {}) + self.use_videos = params.get("use_videos", False) + self.export_success_only = params.get("export_success_only", False) + + # Episode data buffers + self.episode_obs_list: List[Dict] = [] + self.episode_action_list: List[Any] = [] + + # LeRobot dataset instance + self.dataset: Optional[LeRobotDataset] = None + self.dataset_full_path: Optional[Path] = None + + # Tracking + self.total_time: float = 0.0 + self.curr_episode: int = 0 + + # Initialize dataset + self._initialize_dataset() + + logger.log_info(f"LeRobotRecorder initialized at: {self.dataset_path}") + + @property + def dataset_path(self) -> str: + """Path to the dataset directory.""" + return ( + str(self.dataset_full_path) if self.dataset_full_path else "Not initialized" + ) + + def reset(self, env_ids: Optional[torch.Tensor] = None) -> None: + """Reset the recorder buffers. + + Args: + env_ids: Environment IDs to reset (currently clears all data). + """ + self._reset_buffer() + + def __call__( + self, + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + obs: EnvObs, + action: EnvAction, + dones: torch.Tensor, + terminateds: torch.Tensor, + info: Dict[str, Any], + save_path: Optional[str] = None, + id: Optional[str] = None, + robot_meta: Optional[Dict] = None, + instruction: Optional[str] = None, + extra: Optional[Dict] = None, + use_videos: bool = False, + export_success_only: bool = False, + ) -> None: + """Main entry point for the recorder functor. + + This method is called by DatasetManager.apply(mode="save") with runtime arguments + as positional parameters and configuration parameters from cfg.params. + + Args: + env: The environment instance. + env_ids: Environment IDs (for consistency with EventManager pattern). + obs: Observation from the environment. + action: Action applied to the environment. + dones: Boolean tensor indicating which envs completed episodes. + terminateds: Termination flags (success/fail). + info: Info dict containing success/fail information. + save_path: Root directory (already set in __init__). + id: Dataset identifier (already set in __init__). + robot_meta: Robot metadata (already set in __init__). + instruction: Task instruction (already set in __init__). + extra: Extra metadata (already set in __init__). + use_videos: Whether to save videos (already set in __init__). + export_success_only: Whether to export only successful episodes (already set in __init__). + """ + # Always record the step + self._record_step(obs, action) + + # Check if any episodes are done and save them + done_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) + if len(done_env_ids) > 0: + # Save completed episodes + self._save_episodes(done_env_ids, terminateds, info) + + def _record_step(self, obs: EnvObs, action: EnvAction) -> None: + """Record a single step.""" + self.episode_obs_list.append(obs) + self.episode_action_list.append(action) + + def _save_episodes( + self, + env_ids: torch.Tensor, + terminateds: Optional[torch.Tensor] = None, + info: Optional[Dict[str, Any]] = None, + ) -> None: + """Save completed episodes.""" + if len(self.episode_obs_list) == 0: + logger.log_warning("No episode data to save") + return + + obs_list = self.episode_obs_list + action_list = self.episode_action_list + + # Align obs and action + if len(obs_list) > len(action_list): + obs_list = obs_list[:-1] + + task = self.instruction.get("lang", "unknown_task") + + # Update metadata + extra_info = self.extra.copy() if self.extra else {} + fps = self.dataset.meta.info.get("fps", 30) + current_episode_time = (len(obs_list) * len(env_ids)) / fps if fps > 0 else 0 + + episode_extra_info = extra_info.copy() + self.total_time += current_episode_time + episode_extra_info["total_time"] = self.total_time + self._update_dataset_info({"extra": episode_extra_info}) + + # Process each environment + for env_id in env_ids.cpu().tolist(): + is_success = False + if info is not None and "success" in info: + success_tensor = info["success"] + if isinstance(success_tensor, torch.Tensor): + is_success = success_tensor[env_id].item() + else: + is_success = success_tensor + elif terminateds is not None: + is_success = terminateds[env_id].item() + + logger.log_info(f"Episode {env_id} success: {is_success}") + if self.export_success_only and not is_success: + logger.log_info(f"Skipping failed episode for env {env_id}") + continue + + try: + for obs, action in zip(obs_list, action_list): + frame = self._convert_frame_to_lerobot(obs, action, task, env_id) + self.dataset.add_frame(frame) + + self.dataset.save_episode() + logger.log_info( + f"Auto-saved {'successful' if is_success else 'failed'} " + f"episode {self.curr_episode} for env {env_id} with {len(obs_list)} frames" + ) + self.curr_episode += 1 + except Exception as e: + logger.log_error(f"Failed to save episode {env_id}: {e}") + + self._reset_buffer() + + def finalize(self) -> Optional[str]: + """Finalize the dataset.""" + if len(self.episode_obs_list) > 0: + active_env_ids = torch.arange(self.num_envs, device=self.device) + self._save_episodes(active_env_ids) + + try: + if self.dataset is not None: + self.dataset.finalize() + logger.log_info(f"Dataset finalized at: {self.dataset_path}") + return self.dataset_path + except Exception as e: + logger.log_error(f"Failed to finalize dataset: {e}") + + return None + + def _reset_buffer(self) -> None: + """Reset episode buffers.""" + self.episode_obs_list.clear() + self.episode_action_list.clear() + logger.log_info("Reset buffers (cleared all batched data)") + + def _initialize_dataset(self) -> None: + """Initialize the LeRobot dataset.""" + robot_type = self.robot_meta.get("robot_type", "robot") + scene_type = self.extra.get("scene_type", "scene") + task_description = self.extra.get("task_description", "task") + + robot_type = str(robot_type).lower().replace(" ", "_") + task_description = str(task_description).lower().replace(" ", "_") + + # Use lerobot_data_root from __init__ + lerobot_data_root = Path(self.lerobot_data_root) + + # Generate dataset folder name with auto-incrementing suffix + base_name = f"{robot_type}_{scene_type}_{task_description}" + + # Find the next available sequence number by checking existing folders + existing_dirs = list(lerobot_data_root.glob(f"{base_name}_*")) + if not existing_dirs: + dataset_id = 0 + else: + # Extract sequence numbers from existing directories + max_id = -1 + for dir_path in existing_dirs: + suffix = dir_path.name[len(base_name) + 1 :] # +1 for underscore + if suffix.isdigit(): + max_id = max(max_id, int(suffix)) + dataset_id = max_id + 1 + + # Format dataset name with zero-padding (3 digits: 000, 001, 002, ...) + dataset_name = f"{base_name}_{dataset_id:03d}" + + # LeRobot's root parameter is the COMPLETE dataset path (not parent directory) + self.dataset_full_path = lerobot_data_root / dataset_name + + fps = self.robot_meta.get("control_freq", 30) + features = self._build_features() + + logger.log_info("------------------------------------------") + logger.log_info(f"Building dataset: {dataset_name}") + logger.log_info(f"Parent directory: {lerobot_data_root}") + logger.log_info(f"Full path: {self.dataset_full_path}") + + self.dataset = LeRobotDataset.create( + repo_id=dataset_name, + fps=fps, + root=str(self.dataset_full_path), + robot_type=robot_type, + features=features, + use_videos=self.use_videos, + ) + logger.log_info(f"Created LeRobot dataset at: {self.dataset_full_path}") + + def _build_features(self) -> Dict: + """Build LeRobot features dict.""" + features = {} + extra_vision_config = self.robot_meta.get("observation", {}).get("vision", {}) + + for camera_name in extra_vision_config.keys(): + sensor = self._env.get_sensor(camera_name) + is_stereo = is_stereocam(sensor) + img_shape = (sensor.cfg.height, sensor.cfg.width, 3) + + features[camera_name] = { + "dtype": "video" if self.use_videos else "image", + "shape": img_shape, + "names": ["height", "width", "channel"], + } + + if is_stereo: + features[get_right_name(camera_name)] = { + "dtype": "video" if self.use_videos else "image", + "shape": img_shape, + "names": ["height", "width", "channel"], + } + + qpos = self._env.robot.get_qpos() + state_dim = qpos.shape[1] + + if state_dim > 0: + features["observation.state"] = { + "dtype": "float32", + "shape": (state_dim,), + "names": ["state"], + } + + action_dim = self.robot_meta.get("arm_dofs", 7) + features["action"] = { + "dtype": "float32", + "shape": (action_dim,), + "names": ["action"], + } + + return features + + def _convert_frame_to_lerobot( + self, obs: Dict[str, Any], action: Any, task: str, env_id: int + ) -> Dict: + """Convert a single frame to LeRobot format.""" + frame = {"task": task} + extra_vision_config = self.robot_meta.get("observation", {}).get("vision", {}) + arm_dofs = self.robot_meta.get("arm_dofs", 7) + + # Add images + for camera_name in extra_vision_config.keys(): + if camera_name in obs.get("sensor", {}): + sensor = self._env.get_sensor(camera_name) + is_stereo = is_stereocam(sensor) + + color_data = obs["sensor"][camera_name]["color"] + if isinstance(color_data, torch.Tensor): + color_img = color_data[env_id][:, :, :3].cpu().numpy() + else: + color_img = np.array(color_data)[env_id][:, :, :3] + + if color_img.dtype in [np.float32, np.float64]: + color_img = (color_img * 255).astype(np.uint8) + + frame[camera_name] = color_img + + if is_stereo: + color_right_data = obs["sensor"][camera_name]["color_right"] + if isinstance(color_right_data, torch.Tensor): + color_right_img = ( + color_right_data[env_id][:, :, :3].cpu().numpy() + ) + else: + color_right_img = np.array(color_right_data)[env_id][:, :, :3] + + if color_right_img.dtype in [np.float32, np.float64]: + color_right_img = (color_right_img * 255).astype(np.uint8) + + frame[get_right_name(camera_name)] = color_right_img + + # Add state + qpos = obs["robot"][JointType.QPOS.value] + if isinstance(qpos, torch.Tensor): + state_data = qpos[env_id].cpu().numpy().astype(np.float32) + else: + state_data = np.array(qpos)[env_id].astype(np.float32) + + frame["observation.state"] = state_data + + # Add action + if isinstance(action, torch.Tensor): + action_data = action[env_id, :arm_dofs].cpu().numpy() + elif isinstance(action, np.ndarray): + action_data = action[env_id, :arm_dofs] + elif isinstance(action, dict): + action_data = action.get("action", action.get("arm_action", action)) + if isinstance(action_data, torch.Tensor): + action_data = action_data[env_id, :arm_dofs].cpu().numpy() + elif isinstance(action_data, np.ndarray): + action_data = action_data[env_id, :arm_dofs] + else: + action_data = np.array(action)[env_id, :arm_dofs] + + frame["action"] = action_data + + return frame + + def _update_dataset_info(self, updates: dict) -> bool: + """Update dataset metadata.""" + if self.dataset is None: + logger.log_error("LeRobotDataset not initialized.") + return False + + try: + self.dataset.meta.info.update(updates) + return True + except Exception as e: + logger.log_error(f"Failed to update dataset info: {e}") + return False diff --git a/embodichain/lab/gym/envs/managers/object/geometry.py b/embodichain/lab/gym/envs/managers/object/geometry.py index db26b18f..1201dfb0 100644 --- a/embodichain/lab/gym/envs/managers/object/geometry.py +++ b/embodichain/lab/gym/envs/managers/object/geometry.py @@ -47,7 +47,7 @@ def get_pcd_svd_frame(pc: torch.Tensor) -> torch.Tensor: pc_centered = pc - pc_center u, s, vt = torch.linalg.svd(pc_centered) rotation = vt.T - pc_pose = torch.eye(4, dtype=torch.float32) + pc_pose = torch.eye(4, dtype=torch.float32, device=pc.device) pc_pose[:3, :3] = rotation pc_pose[:3, 3] = pc_center return pc_pose @@ -90,7 +90,7 @@ def apply_svd_transfer_pcd( standard_verts = [] for object_verts in verts: pc_svd_frame = get_pcd_svd_frame(object_verts) - inv_svd_frame = inv_transform(pc_svd_frame) + inv_svd_frame = torch.linalg.inv(pc_svd_frame) standard_object_verts = ( object_verts @ inv_svd_frame[:3, :3].T + inv_svd_frame[:3, 3] ) @@ -175,7 +175,7 @@ def compute_object_length( ) pcs = rigid_object.get_vertices(env_ids) body_scale = rigid_object.get_body_scale(env_ids) - scaled_pcs = pcs * body_scale + scaled_pcs = pcs * body_scale.unsqueeze(1) if is_svd_frame: scaled_pcs = apply_svd_transfer_pcd(scaled_pcs, sample_points) diff --git a/embodichain/lab/gym/envs/tasks/special/simple_task.py b/embodichain/lab/gym/envs/tasks/special/simple_task.py new file mode 100644 index 00000000..5c9c879d --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/special/simple_task.py @@ -0,0 +1,88 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.utils import logger + +__all__ = ["SimpleTaskEnv"] + + +@register_env("SimpleTask-v1", max_episode_steps=600) +class SimpleTaskEnv(EmbodiedEnv): + """A demo environment with sinusoidal trajectory + + Args: + EmbodiedEnv (_type_): _description_ + """ + + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + + def create_demo_action_list(self, *args, **kwargs): + """ + Create a demonstration action list for the current task. + + This demo creates a simple sinusoidal trajectory for the robot joints. + + Returns: + list: A list of demo actions generated by the task. + """ + action_list = [] + num_steps = 100 + + # Get initial pose + init_pose = self.robot.get_qpos() # shape: (num_envs, num_joints) + + # Create a sinusoidal trajectory + for i in range(num_steps): + # Calculate phase for sinusoidal motion + t = i / num_steps # 0 to 1 + phase = torch.full( + (init_pose.shape[0],), t * 2 * torch.pi, device=self.device + ) # repeat for num_envs + + # Create sinusoidal offsets for each joint + # Joint 0: horizontal movement + # Joint 1: vertical movement + # Other joints: smaller oscillations + offset = torch.zeros_like( + init_pose, dtype=torch.float32, device=self.device + ) + offset[:, 0] = torch.sin(phase) * 0.3 # ±0.3 rad + offset[:, 1] = torch.cos(phase) * 0.2 # ±0.2 rad + offset[:, 2] = torch.sin(phase * 2) * 0.1 # ±0.1 rad, double frequency + + # Add small random variation to make it more natural + noise = (torch.rand_like(init_pose, device=self.device) - 0.5) * 0.02 + + # Compute action + action = init_pose + offset + noise + + # Clamp to joint limits if available + if hasattr(self.robot.body_data, "qpos_limits"): + qpos_limits = self.robot.body_data.qpos_limits[0] # (num_joints, 2) + action = torch.clamp(action, qpos_limits[:, 0], qpos_limits[:, 1]) + + action_list.append(action) + + logger.log_info( + f"Generated {len(action_list)} demo actions with sinusoidal trajectory" + ) + self.action_length = len(action_list) + return action_list diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py index 433b76de..4d575b93 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py @@ -57,6 +57,7 @@ def create_demo_action_list(self, *args, **kwargs): logger.log_info( f"Demo action list created with {len(action_list)} steps.", color="green" ) + self.action_length = len(action_list) return action_list def create_expert_demo_action_list(self, **kwargs): diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 0f92abb7..ebe06d6e 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -364,6 +364,7 @@ def config_to_cfg(config: dict) -> "EmbodiedEnvCfg": SceneEntityCfg, EventCfg, ObservationCfg, + DatasetFunctorCfg, ) from embodichain.utils import configclass from embodichain.data import get_data_path @@ -453,7 +454,32 @@ class ComponentCfg: env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) # load dataset config - env_cfg.dataset = config["env"].get("dataset", None) + env_cfg.dataset = ComponentCfg() + if "dataset" in config["env"]: + # Define modules to search for dataset functions + dataset_modules = [ + "embodichain.lab.gym.envs.managers.datasets", + ] + + for dataset_name, dataset_params in config["env"]["dataset"].items(): + dataset_params_modified = deepcopy(dataset_params) + + # Find the function from multiple modules using the utility function + dataset_func = find_function_from_modules( + dataset_params["func"], + dataset_modules, + raise_if_not_found=True, + ) + + from embodichain.lab.gym.envs.managers import DatasetFunctorCfg + + dataset = DatasetFunctorCfg( + func=dataset_func, + mode=dataset_params_modified["mode"], + params=dataset_params_modified["params"], + ) + + setattr(env_cfg.dataset, dataset_name, dataset) # TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc. diff --git a/embodichain/lab/gym/utils/misc.py b/embodichain/lab/gym/utils/misc.py index b669e6cf..b75b70af 100644 --- a/embodichain/lab/gym/utils/misc.py +++ b/embodichain/lab/gym/utils/misc.py @@ -1367,3 +1367,18 @@ def is_eef_hand(robot, control_parts) -> bool: if "gripper" in data_key and is_eef_hand(robot, control_parts) is False: return "right_eef" return None + + +def is_stereocam(sensor) -> bool: + """ + Check if a sensor is a StereoCamera (binocular camera). + + Args: + sensor: The sensor instance to check. + + Returns: + bool: True if the sensor is a StereoCamera, False otherwise. + """ + from embodichain.lab.sim.sensors import StereoCamera + + return isinstance(sensor, StereoCamera) diff --git a/embodichain/lab/gym/utils/registration.py b/embodichain/lab/gym/utils/registration.py index b2fe5b62..9c52e103 100644 --- a/embodichain/lab/gym/utils/registration.py +++ b/embodichain/lab/gym/utils/registration.py @@ -18,6 +18,8 @@ import json import sys +import torch + from copy import deepcopy from functools import partial from typing import TYPE_CHECKING, Dict, Type @@ -107,6 +109,14 @@ def __init__(self, env: gym.Env, max_episode_steps: int): def base_env(self) -> BaseEnv: return self.env.unwrapped + @property + def device(self) -> torch.device: + return self.base_env.device + + @property + def num_envs(self) -> int: + return self.base_env.num_envs + def step(self, action): observation, reward, terminated, truncated, info = self.env.step(action) truncated = truncated | (self.base_env.elapsed_steps >= self._max_episode_steps) diff --git a/embodichain/lab/scripts/run_env.py b/embodichain/lab/scripts/run_env.py index 1ad5318f..2268be0f 100644 --- a/embodichain/lab/scripts/run_env.py +++ b/embodichain/lab/scripts/run_env.py @@ -34,7 +34,7 @@ def generate_and_execute_action_list(env, idx, debug_mode): - action_list = env.create_demo_action_list(action_sentence=idx) + action_list = env.get_wrapper_attr("create_demo_action_list")(action_sentence=idx) if action_list is None or len(action_list) == 0: log_warning("Action is invalid. Skip to next generation.") @@ -44,10 +44,9 @@ def generate_and_execute_action_list(env, idx, debug_mode): action_list, desc=f"Executing action list #{idx}", unit="step" ): # Step the environment with the current action + # The environment will automatically detect truncation based on action_length obs, reward, terminated, truncated, info = env.step(action) - # TODO: May be add some functions for debug_mode - # TODO: We may assume in export demonstration rollout, there is no truncation from the env. # but truncation is useful to improve the generation efficiency. @@ -83,24 +82,33 @@ def generate_function( """ valid = True + _, _ = env.reset() while True: - _, _ = env.reset() - + ret = [] for trajectory_idx in range(num_traj): valid = generate_and_execute_action_list(env, trajectory_idx, debug_mode) if not valid: + _, _ = env.reset() break - if not debug_mode and env.is_task_success().item(): - pass - - # TODO: Add data saving and online data streaming logic here. - + # Check task success for all environments + if not debug_mode: + success = env.get_wrapper_attr("is_task_success")() + # For multiple environments, check if all succeeded + all_success = ( + success.all().item() if success.numel() > 1 else success.item() + ) + if all_success: + pass + # TODO: Add data saving and online data streaming logic here. + else: + log_warning(f"Task fail, Skip to next generation.") + valid = False + break else: - log_warning(f"Task fail, Skip to next generation.") - valid = False - break + # In debug mode, skip success check + pass if valid: break @@ -188,8 +196,8 @@ def main(args, env, gym_config): args = parser.parse_args() - if args.num_envs != 1: - log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + # if args.num_envs != 1: + # log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") gym_config = load_json(args.gym_config) cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) diff --git a/embodichain/lab/sim/utility/workspace_analyzer/__init__.py b/embodichain/lab/sim/utility/workspace_analyzer/__init__.py index 57074086..a2b3e80a 100644 --- a/embodichain/lab/sim/utility/workspace_analyzer/__init__.py +++ b/embodichain/lab/sim/utility/workspace_analyzer/__init__.py @@ -53,9 +53,6 @@ from embodichain.lab.sim.utility.workspace_analyzer import visualizers from embodichain.lab.sim.utility.workspace_analyzer import metrics from embodichain.lab.sim.utility.workspace_analyzer import constraints -from embodichain.lab.sim.utility.workspace_analyzer.workspace_sampler import ( - sample_circular_plane_reachability, -) __all__ = [ "WorkspaceAnalyzer", @@ -67,5 +64,4 @@ "visualizers", "metrics", "constraints", - "sample_circular_plane_reachability", ] diff --git a/embodichain/lab/sim/utility/workspace_analyzer/constraints/__init__.py b/embodichain/lab/sim/utility/workspace_analyzer/constraints/__init__.py index b6df4e52..9c02f032 100644 --- a/embodichain/lab/sim/utility/workspace_analyzer/constraints/__init__.py +++ b/embodichain/lab/sim/utility/workspace_analyzer/constraints/__init__.py @@ -16,11 +16,9 @@ from .base_constraint import BaseConstraintChecker, IConstraintChecker from .workspace_constraint import WorkspaceConstraintChecker -from .circular_constraint import CircularConstraintChecker __all__ = [ "BaseConstraintChecker", "IConstraintChecker", "WorkspaceConstraintChecker", - "CircularConstraintChecker", ] diff --git a/pyproject.toml b/pyproject.toml index ad6d6d2e..f1381090 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dynamic = ["version"] dependencies = [ "dexsim_engine==0.3.8", "setuptools>=78.1.1", - "gymnasium==0.29.1", + "gymnasium>=0.29.1", "casadi==3.7.1", "pin==2.7.0", "toppra==0.6.3", @@ -38,7 +38,7 @@ dependencies = [ "pytorch_kinematics==0.7.6", "polars==1.31.0", "PyYAML>=6.0", - "accelerate==1.2.1", + "accelerate>=1.10.0", "wandb==0.20.1", "tensorboard", "transformers>=4.53.0", @@ -51,6 +51,11 @@ dependencies = [ "h5py", ] +[project.optional-dependencies] +lerobot = [ + "lerobot==0.4.2" +] + [tool.setuptools.dynamic] version = { file = ["VERSION"] } diff --git a/scripts/tutorials/gym/random_reach.py b/scripts/tutorials/gym/random_reach.py index a6e4ed0a..d9912237 100644 --- a/scripts/tutorials/gym/random_reach.py +++ b/scripts/tutorials/gym/random_reach.py @@ -144,7 +144,7 @@ def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: action = env.action_space.sample() action = torch.as_tensor(action, dtype=torch.float32, device=env.device) - init_pose = env.robot_init_qpos + init_pose = env.unwrapped.robot_init_qpos init_pose = ( torch.as_tensor(init_pose, dtype=torch.float32, device=env.device) .unsqueeze_(0) diff --git a/scripts/tutorials/sim/create_sensor.py b/scripts/tutorials/sim/create_sensor.py index 52db7b87..1ae8f5b6 100644 --- a/scripts/tutorials/sim/create_sensor.py +++ b/scripts/tutorials/sim/create_sensor.py @@ -247,7 +247,7 @@ def get_sensor_image(camera: Camera, headless=False, step_count=0): normals = data["normal"].cpu().numpy()[0] # (H, W, 3) # Normalize for visualization - depth_vis = (depth - depth.min()) / (depth.ptp() + 1e-8) + depth_vis = (depth - depth.min()) / (np.ptp(depth) + 1e-8) depth_vis = (depth_vis * 255).astype("uint8") mask_vis = mask_to_color_map(mask, user_ids=np.unique(mask)) normals_vis = ((normals + 1) / 2 * 255).astype("uint8") diff --git a/tests/gym/envs/test_base_env.py b/tests/gym/envs/test_base_env.py index e9d1f3ab..bed26a71 100644 --- a/tests/gym/envs/test_base_env.py +++ b/tests/gym/envs/test_base_env.py @@ -136,7 +136,7 @@ def test_env_rollout(self): action, dtype=torch.float32, device=self.env.device ) - init_pose = self.env.robot_init_qpos + init_pose = self.env.get_wrapper_attr("robot_init_qpos") init_pose = ( torch.as_tensor( init_pose, dtype=torch.float32, device=self.env.device diff --git a/tests/gym/envs/test_embodied_env.py b/tests/gym/envs/test_embodied_env.py index 29ea0d85..574fd60c 100644 --- a/tests/gym/envs/test_embodied_env.py +++ b/tests/gym/envs/test_embodied_env.py @@ -135,20 +135,22 @@ def test_env_rollout(self): for i in range(2): action = self.env.action_space.sample() action = torch.as_tensor( - action, dtype=torch.float32, device=self.env.device + action, + dtype=torch.float32, + device=self.env.get_wrapper_attr("device"), ) obs, reward, done, truncated, info = self.env.step(action) assert reward.shape == ( - self.env.num_envs, - ), f"Expected reward shape ({self.env.num_envs},), got {reward.shape}" + self.env.get_wrapper_attr("num_envs"), + ), f"Expected reward shape ({self.env.get_wrapper_attr('num_envs')},), got {reward.shape}" assert done.shape == ( - self.env.num_envs, - ), f"Expected done shape ({self.env.num_envs},), got {done.shape}" + self.env.get_wrapper_attr("num_envs"), + ), f"Expected done shape ({self.env.get_wrapper_attr('num_envs')},), got {done.shape}" assert truncated.shape == ( - self.env.num_envs, - ), f"Expected truncated shape ({self.env.num_envs},), got {truncated.shape}" + self.env.get_wrapper_attr("num_envs"), + ), f"Expected truncated shape ({self.env.get_wrapper_attr('num_envs')},), got {truncated.shape}" assert obs.get("robot") is not None, "Expected 'robot' info in the info dict" def teardown_method(self):