diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py
index 9d3ed1893e..1b3fc75128 100644
--- a/src/lerobot/__init__.py
+++ b/src/lerobot/__init__.py
@@ -183,6 +183,7 @@ available_cameras = [
     "opencv",
     "intelrealsense",
+    "zed",
 ]
 
 # lists all available motors from `lerobot/motors`
diff --git a/src/lerobot/cameras/__init__.py b/src/lerobot/cameras/__init__.py
index 1488cd89ea..ffca9580bb 100644
--- a/src/lerobot/cameras/__init__.py
+++ b/src/lerobot/cameras/__init__.py
@@ -15,3 +15,18 @@
 from .camera import Camera
 from .configs import CameraConfig, ColorMode, Cv2Rotation
 from .utils import make_cameras_from_configs
+
+from .opencv.configuration_opencv import OpenCVCameraConfig
+from .realsense.configuration_realsense import RealSenseCameraConfig
+from .zed.configuration_zed import ZedCameraConfig
+
+__all__ = [
+    "Camera",
+    "CameraConfig",
+    "ColorMode",
+    "Cv2Rotation",
+    "make_cameras_from_configs",
+    "OpenCVCameraConfig",
+    "RealSenseCameraConfig",
+    "ZedCameraConfig",
+]
diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py
index e435c7309a..9fd2162590 100644
--- a/src/lerobot/cameras/camera.py
+++ b/src/lerobot/cameras/camera.py
@@ -89,28 +89,34 @@ def connect(self, warmup: bool = True) -> None:
         pass
 
     @abc.abstractmethod
-    def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
+    def read(self, color_mode: ColorMode | None = None) -> dict[str, np.ndarray]:
         """Capture and return a single frame from the camera.
 
         Args:
-            color_mode: Desired color mode for the output frame. If None,
-                uses the camera's default color mode.
+            color_mode: Desired color mode for the output frame. If None, uses the camera's default color mode.
 
         Returns:
-            np.ndarray: Captured frame as a numpy array.
+            dict[str, np.ndarray]: Dictionary with modality keys and image data.
+                The keys are automatically determined based on array shape:
+                - 'gray': For 2D arrays (H, W) - grayscale images
+                - 'gray': For 3D arrays (H, W, 1) - grayscale images
+                - 'rgb': For 3D arrays (H, W, 3) - RGB/BGR images
+                - 'rgba': For 3D arrays (H, W, 4) - RGBA images
+                - 'depth': For depth maps (H, W)
+                Additional modality-specific keys may be provided by specific cameras.
         """
         pass
 
     @abc.abstractmethod
-    def async_read(self, timeout_ms: float = ...) -> np.ndarray:
+    def async_read(self, timeout_ms: float = ...) -> dict[str, np.ndarray]:
         """Asynchronously capture and return a single frame from the camera.
 
         Args:
-            timeout_ms: Maximum time to wait for a frame in milliseconds.
-                Defaults to implementation-specific timeout.
+            timeout_ms: Maximum time to wait for a frame in milliseconds. Defaults to implementation-specific timeout.
 
         Returns:
-            np.ndarray: Captured frame as a numpy array.
+            dict[str, np.ndarray]: Dictionary mapping modality keys to image data.
+                See the read() method for detailed documentation on modality keys and image formats.
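+
+        Example (illustrative; the keys come from `get_image_modality_key`, e.g. 'rgb', 'gray', 'depth'):
+            frames = camera.async_read(timeout_ms=200)
+            rgb = frames["rgb"]            # (H, W, 3) color image
+            depth = frames.get("depth")    # (H, W) depth map, when the camera provides one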
""" pass diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 50e55f0c22..acf968357b 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -34,7 +34,7 @@ from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..camera import Camera -from ..utils import get_cv2_backend, get_cv2_rotation +from ..utils import get_cv2_backend, get_cv2_rotation, get_image_modality_key from .configuration_opencv import ColorMode, OpenCVCameraConfig # NOTE(Steven): The maximum opencv device index depends on your operating system. For instance, @@ -289,7 +289,7 @@ def find_cameras() -> list[dict[str, Any]]: return found_cameras_info - def read(self, color_mode: ColorMode | None = None) -> np.ndarray: + def read(self, color_mode: ColorMode | None = None) -> dict[str, np.ndarray]: """ Reads a single frame synchronously from the camera. @@ -301,11 +301,6 @@ def read(self, color_mode: ColorMode | None = None) -> np.ndarray: color mode (`self.color_mode`) for this read operation (e.g., request RGB even if default is BGR). - Returns: - np.ndarray: The captured frame as a NumPy array in the format - (height, width, channels), using the specified or default - color mode and applying any configured rotation. - Raises: DeviceNotConnectedError: If the camera is not connected. RuntimeError: If reading the frame from the camera fails or if the @@ -327,7 +322,8 @@ def read(self, color_mode: ColorMode | None = None) -> np.ndarray: read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return processed_frame + image_modality_key = get_image_modality_key(image=processed_frame) + return {image_modality_key: processed_frame} def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray: """ @@ -419,7 +415,7 @@ def _stop_read_thread(self) -> None: self.thread = None self.stop_event = None - def async_read(self, timeout_ms: float = 200) -> np.ndarray: + def async_read(self, timeout_ms: float = 200) -> dict[str, np.ndarray]: """ Reads the latest available frame asynchronously. @@ -432,7 +428,7 @@ def async_read(self, timeout_ms: float = 200) -> np.ndarray: to become available. Defaults to 200ms (0.2 seconds). Returns: - np.ndarray: The latest captured frame as a NumPy array in the format + dict[str, np.ndarray]: A map of the latest captured frames as a NumPy array in the format (height, width, channels), processed according to configuration. 
Raises: diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index c96789f969..46249ba557 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -23,6 +23,8 @@ from threading import Event, Lock, Thread from typing import Any +from ..utils import get_image_modality_key + # Fix MSMF hardware transform compatibility for Windows before importing cv2 if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ: os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" @@ -131,7 +133,7 @@ def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[ camera_manager.disconnect() return initialized_cameras - def read(self, color_mode: ColorMode | None = None) -> np.ndarray: + def read(self, color_mode: ColorMode | None = None) -> dict[str, np.ndarray]: """ Reads a single frame synchronously from the camera. @@ -143,7 +145,7 @@ def read(self, color_mode: ColorMode | None = None) -> np.ndarray: request RGB even if default is BGR). Returns: - np.ndarray: The captured frame as a NumPy array in the format + dict[str, np.ndarray]: The captured frame as a NumPy array in the format (height, width, channels), using the specified or default color mode and applying any configured rotation. """ @@ -177,7 +179,8 @@ def read(self, color_mode: ColorMode | None = None) -> np.ndarray: read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return frame + image_modality_key = get_image_modality_key(image=frame) + return {image_modality_key: frame} def _read_loop(self): """ @@ -226,7 +229,7 @@ def _stop_read_thread(self) -> None: self.thread = None self.stop_event = None - def async_read(self, timeout_ms: float = 200) -> np.ndarray: + def async_read(self, timeout_ms: float = 200) -> dict[str, np.ndarray]: """ Reads the latest available frame asynchronously. @@ -239,7 +242,7 @@ def async_read(self, timeout_ms: float = 200) -> np.ndarray: to become available. Defaults to 200ms (0.2 seconds). Returns: - np.ndarray: The latest captured frame as a NumPy array in the format + dict[str, np.ndarray]: A map of the latest captured frames as a NumPy array in the format (height, width, channels), processed according to configuration. 
        Raises:
@@ -267,7 +270,8 @@ def async_read(self, timeout_ms: float = 200) -> np.ndarray:
         if frame is None:
             raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
 
-        return frame
+        image_modality_key = get_image_modality_key(image=frame)
+        return {image_modality_key: frame}
 
     def disconnect(self):
         """
diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py
index cc816e5525..86e490553d 100644
--- a/src/lerobot/cameras/realsense/camera_realsense.py
+++ b/src/lerobot/cameras/realsense/camera_realsense.py
@@ -33,7 +33,7 @@
 from ..camera import Camera
 from ..configs import ColorMode
-from ..utils import get_cv2_rotation
+from ..utils import get_cv2_rotation, get_image_modality_key
 from .configuration_realsense import RealSenseCameraConfig
 
 logger = logging.getLogger(__name__)
@@ -351,7 +351,7 @@ def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
 
         return depth_map_processed
 
-    def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
+    def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> dict[str, np.ndarray]:
         """
         Reads a single frame (color) synchronously from the camera.
 
@@ -362,7 +362,7 @@ def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np
             timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
 
         Returns:
-            np.ndarray: The captured color frame as a NumPy array
+            dict[str, np.ndarray]: Dictionary mapping the modality key to the captured color frame as a NumPy array
             (height, width, channels), processed according to `color_mode` and rotation.
 
         Raises:
@@ -389,7 +389,8 @@ def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np
         read_duration_ms = (time.perf_counter() - start_time) * 1e3
         logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
 
-        return color_image_processed
+        image_modality_key = get_image_modality_key(image=color_image_processed)
+        return {image_modality_key: color_image_processed}
 
     def _postprocess_image(
         self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
@@ -486,7 +487,7 @@ def _stop_read_thread(self):
         self.stop_event = None
 
     # NOTE(Steven): Missing implementation for depth for now
-    def async_read(self, timeout_ms: float = 200) -> np.ndarray:
+    def async_read(self, timeout_ms: float = 200) -> dict[str, np.ndarray]:
         """
         Reads the latest available frame data (color) asynchronously.
 
@@ -499,8 +500,8 @@ def async_read(self, timeout_ms: float = 200) -> np.ndarray:
             to become available. Defaults to 200ms (0.2 seconds).
 
         Returns:
-            np.ndarray:
-                The latest captured frame data (color image), processed according to configuration.
+            dict[str, np.ndarray]:
+                Dictionary mapping the modality key to the latest captured frame data (color image), processed according to configuration.
 
         Raises:
             DeviceNotConnectedError: If the camera is not connected.
@@ -527,7 +528,8 @@ def async_read(self, timeout_ms: float = 200) -> np.ndarray: if frame is None: raise RuntimeError(f"Internal error: Event set but no frame available for {self}.") - return frame + image_modality_key = get_image_modality_key(image=frame) + return {image_modality_key: frame} def disconnect(self): """ diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index aa6ff98b48..65dbe3daa2 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -17,6 +17,8 @@ import platform from typing import cast +import numpy as np + from lerobot.utils.import_utils import make_device_from_device_class from .camera import Camera @@ -43,6 +45,11 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s cameras[key] = Reachy2Camera(cfg) + elif cfg.type == "zed": + from .zed import ZedCamera + + cameras[key] = ZedCamera(cfg) + else: try: cameras[key] = cast(Camera, make_device_from_device_class(cfg)) @@ -74,3 +81,47 @@ def get_cv2_backend() -> int: # return cv2.CAP_AVFOUNDATION else: # Linux and others return cv2.CAP_ANY + + +def get_image_modality_key(image: np.ndarray, is_depth: bool = False) -> str: + """ + Determine the modality key based on image array shape. + + Args: + image: Image array from the camera. Should be a numpy array. + is_depth: If True, explicitly treat as depth modality regardless of shape. + + Returns: + str: Modality key indicating the image type: + - 'depth': Depth maps + - 'gray': Grayscale images (H, W) or (H, W, 1) + - 'rgb': RGB images (H, W, 3) + - 'rgba': RGBA images (H, W, 4) + - 'unknown': For unsupported array shapes + + Raises: + ValueError: If image is not a numpy array + + Example: + >>> get_image_modality_key(np.zeros((480, 640))) # 'gray' + >>> get_image_modality_key(np.zeros((480, 640, 3))) # 'rgb' + >>> get_image_modality_key(np.zeros((480, 640)), is_depth=True) # 'depth' + """ + if not isinstance(image, np.ndarray): + raise ValueError(f"Expected numpy array, got {type(image)}") + + if is_depth: + return "depth" + + if len(image.shape) == 2: + return "gray" + elif len(image.shape) == 3: + channels = image.shape[2] + if channels == 1: + return "gray" + elif channels == 3: + return "rgb" + elif channels == 4: + return "rgba" + + return "unknown" diff --git a/src/lerobot/cameras/zed/__init__.py b/src/lerobot/cameras/zed/__init__.py new file mode 100644 index 0000000000..5a5b95c64b --- /dev/null +++ b/src/lerobot/cameras/zed/__init__.py @@ -0,0 +1,2 @@ +from .camera_zed import ZedCamera +from .configuration_zed import ZedCameraConfig diff --git a/src/lerobot/cameras/zed/camera_zed.py b/src/lerobot/cameras/zed/camera_zed.py new file mode 100644 index 0000000000..dc23eaae68 --- /dev/null +++ b/src/lerobot/cameras/zed/camera_zed.py @@ -0,0 +1,549 @@ +""" +Provides the ZedCamera class for capturing frames from ZED stereo cameras. +""" + +import logging +import time +from threading import Event, Lock, Thread +from typing import Any + +import cv2 +import numpy as np + +import pyzed.sl as sl + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..camera import Camera +from ..configs import ColorMode +from ..utils import get_cv2_rotation, get_image_modality_key +from .configuration_zed import ZedCameraConfig + +logger = logging.getLogger(__name__) + + +class ZedCamera(Camera): + """ + Manages interactions with ZED stereo cameras for frame and depth recording. + + This class provides an interface for ZED cameras, leveraging the `pyzed.sl` library. 
+ It uses the camera's unique serial number for identification. ZED cameras support + high-quality depth sensing and various resolutions. + + Use the provided utility script to find available camera indices and default profiles: + ```bash + lerobot-find-cameras zed + ``` + + A `ZedCamera` instance requires a configuration object specifying the + camera's serial number or a unique device name. + + Example: + ```python + from lerobot.cameras.zed import ZedCamera, ZedCameraConfig + from lerobot.cameras import ColorMode, Cv2Rotation + + # Basic usage with serial number + config = ZedCameraConfig(serial_number_or_name="0123456789") # Replace with actual SN + camera = ZedCamera(config) + camera.connect() + + # Read 1 frame synchronously + color_image = camera.read() + print(color_image.shape) + + # Read 1 depth frame + depth_map = camera.read_depth() + + # When done, properly disconnect the camera. + camera.disconnect() + + # Example with custom settings + custom_config = ZedCameraConfig( + serial_number_or_name="0123456789", + fps=30, + width=1280, + height=720, + color_mode=ColorMode.BGR, + rotation=Cv2Rotation.ROTATE_90, + use_depth=True, + depth_mode="NEURAL" + ) + depth_camera = ZedCamera(custom_config) + depth_camera.connect() + ``` + """ + + def __init__(self, config: ZedCameraConfig): + """ + Initializes the ZedCamera instance. + + Args: + config: The configuration settings for the camera. + """ + + super().__init__(config) + + self.config = config + + if config.serial_number_or_name.isdigit(): + self.serial_number = config.serial_number_or_name + else: + self.serial_number = self._find_serial_number_from_name(config.serial_number_or_name) + + self.fps = config.fps + self.color_mode = config.color_mode + self.use_depth = config.use_depth + self.warmup_s = config.warmup_s + self.depth_mode = config.depth_mode + + self.zed_camera: sl.Camera | None = None + self.runtime_params: sl.RuntimeParameters | None = None + self.mat_resolution: sl.Resolution | None = None + + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_lock: Lock = Lock() + self.latest_frame: dict[str, np.ndarray] | None = None + self.new_frame_event: Event = Event() + + self.rotation: int | None = get_cv2_rotation(config.rotation) + + if self.height and self.width: + self.capture_width, self.capture_height = self.width, self.height + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.capture_width, self.capture_height = self.height, self.width + + # ZED specific attributes + self.image_mat = sl.Mat() + self.depth_mat = sl.Mat() + self.point_cloud_mat = sl.Mat() + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.serial_number})" + + @property + def is_connected(self) -> bool: + """Checks if the ZED camera is opened and streaming.""" + return self.zed_camera is not None and self.zed_camera.is_opened() + + def connect(self, warmup: bool = True): + """ + Connects to the ZED camera specified in the configuration. + + Initializes the ZED camera, configures the required parameters, + starts the camera, and validates the actual stream settings. + + Raises: + DeviceAlreadyConnectedError: If the camera is already connected. + ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique). + ConnectionError: If the camera is found but fails to open or no ZED devices are detected. + RuntimeError: If the camera starts but fails to apply requested settings. 
+ """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") + + # Create ZED camera object + self.zed_camera = sl.Camera() + + # Set initialization parameters + init_params = sl.InitParameters() + init_params.camera_resolution = sl.RESOLUTION.HD720 # Default, can be overridden + init_params.camera_fps = self.fps or 30 + init_params.depth_mode = self._get_zed_depth_mode() + init_params.coordinate_units = sl.UNIT.METER + init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP + init_params.set_from_serial_number(int(self.serial_number)) + + # Set depth minimum and maximum range in meters:cite[4] + init_params.depth_minimum_distance = 0.3 + init_params.depth_maximum_distance = 20 + + # Open the camera + err = self.zed_camera.open(init_params) + if err != sl.ERROR_CODE.SUCCESS: + self.zed_camera = None + raise ConnectionError( + f"Failed to open {self}. Error code: {err}. " + f"Run `lerobot-find-cameras zed` to find available cameras." + ) + + # Configure runtime parameters + self.runtime_params = sl.RuntimeParameters(enable_depth=True) + + # Set mat resolution based on configuration + self._configure_mat_resolution() + + if warmup: + # ZED cameras need longer warmup time:cite[4] + time.sleep(self.warmup_s) + start_time = time.time() + while time.time() - start_time < self.warmup_s: + self.read() + time.sleep(0.1) + + logger.info(f"{self} connected.") + + def _get_zed_depth_mode(self) -> sl.DEPTH_MODE: + """Converts depth mode string to ZED depth mode enum.""" + if not self.use_depth: + return sl.DEPTH_MODE.NONE + + mode_map = { + "QUALITY": sl.DEPTH_MODE.QUALITY, + "ULTRA": sl.DEPTH_MODE.ULTRA, + "NEURAL": sl.DEPTH_MODE.NEURAL + } + return mode_map.get(self.depth_mode, sl.DEPTH_MODE.NEURAL) + + def _configure_mat_resolution(self) -> None: + """Configures the matrix resolution based on camera settings.""" + if self.width and self.height: + # Use custom resolution + self.mat_resolution = sl.Resolution(self.capture_width, self.capture_height) + else: + # Use camera's default resolution + camera_info = self.zed_camera.get_camera_information() + self.mat_resolution = camera_info.camera_configuration.resolution + self.width = self.mat_resolution.width + self.height = self.mat_resolution.height + self.capture_width = self.width + self.capture_height = self.height + + # Update fps if not set + if self.fps is None: + self.fps = camera_info.camera_configuration.fps + + @staticmethod + def find_cameras() -> list[dict[str, Any]]: + """ + Detects available ZED cameras connected to the system. 
+ """ + found_cameras_info = [] + + try: + # Get list of connected devices + device_list = sl.Camera.get_device_list() + + for device in device_list: + # Create camera information dictionary with available attributes + camera_info = { + "name": str(device.camera_name), + "type": "ZED", + "id": str(device.serial_number), + "model": str(device.camera_model), + "state": str(device.camera_state), + } + + # Get resolution and FPS through camera initialization + try: + zed = sl.Camera() + init_params = sl.InitParameters() + init_params.set_from_serial_number(device.serial_number) + + # Open camera briefly to read configuration + if zed.open(init_params) == sl.ERROR_CODE.SUCCESS: + camera_config = ( + zed.get_camera_information().camera_configuration + ) + camera_info["resolution"] = ( + f"{camera_config.resolution.width}x{camera_config.resolution.height}" + ) + camera_info["fps"] = camera_config.fps + camera_info["firmware"] = camera_config.firmware_version + zed.close() + except Exception as e: + logger.warning( + f"Could not read full configuration for ZED {device.serial_number}: {e}" + ) + # Set default values if camera initialization fails + camera_info["resolution"] = "1920x1080" # Default resolution + camera_info["fps"] = 30 # Default FPS + + found_cameras_info.append(camera_info) + + except Exception as e: + logger.error(f"Error enumerating ZED devices: {e}") + + return found_cameras_info + + def _find_serial_number_from_name(self, name: str) -> str: + """Finds the serial number for a given unique camera name.""" + camera_infos = self.find_cameras() + found_devices = [cam for cam in camera_infos if str(cam["name"]) == name] + + if not found_devices: + available_names = [cam["name"] for cam in camera_infos] + raise ValueError( + f"No ZED camera found with name '{name}'. Available camera names: {available_names}" + ) + + if len(found_devices) > 1: + serial_numbers = [dev["id"] for dev in found_devices] + raise ValueError( + f"Multiple ZED cameras found with name '{name}'. " + f"Please use a unique serial number instead. Found SNs: {serial_numbers}" + ) + + serial_number = str(found_devices[0]["id"]) + return serial_number + + def read_depth(self, timeout_ms: int = 200) -> np.ndarray: + """ + Reads a single frame (depth) synchronously from the camera. + + This is a blocking call. It waits for a depth frame from the ZED camera. + + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. + + Returns: + np.ndarray: The depth map as a NumPy array (height, width) + of type `np.uint16` (raw depth values in millimeters) with rotation applied. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the camera fails. + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if not self.use_depth: + raise RuntimeError( + f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." 
+ ) + + start_time = time.perf_counter() + + # Grab a frame with timeout + if not self.zed_camera.grab(self.runtime_params): + raise RuntimeError(f"{self} read_depth failed to grab frame.") + + # Retrieve depth map + self.zed_camera.retrieve_measure(py_mat=self.depth_mat, measure=sl.MEASURE.DEPTH, resolution=self.mat_resolution) + depth_map = self.depth_mat.get_data() + + depth_map_processed = self._postprocess_image(depth_map, depth_frame=True) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read_depth took: {read_duration_ms:.1f}ms") + + return depth_map_processed + + def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> dict[str, np.ndarray]: + """ + Reads a single frame (color) synchronously from the camera. + + This is a blocking call. It waits for a color frame from the ZED camera. + + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. + + Returns: + dict[str, np.ndarray]: A map of the captured color frame as a NumPy array + (height, width, channels), processed according to `color_mode` and rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the camera fails. + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start_time = time.perf_counter() + + # Grab a frame with timeout handling + if not self.zed_camera.grab(self.runtime_params): + raise RuntimeError(f"{self} read failed to grab frame.") + + # Retrieve left image (RGB by default in ZED SDK) + self.zed_camera.retrieve_image(py_mat=self.image_mat,view= self.config.camera_view, resolution=self.mat_resolution) + color_image_raw = self.image_mat.get_data() + color_image_processed = self._postprocess_image(color_image_raw, color_mode) + + depth_key: str | None = None + depth_map_processed: np.ndarray | None = None + if self.use_depth: + # Retrieve depth map + self.zed_camera.retrieve_measure(py_mat=self.depth_mat, measure=sl.MEASURE.DEPTH, resolution=self.mat_resolution) + depth_map = self.depth_mat.get_data() + depth_map_processed = self._postprocess_image(depth_map, depth_frame=True) + depth_key = get_image_modality_key(image=depth_map_processed, is_depth=True) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + rgb_key = get_image_modality_key(image=color_image_processed) + images = {rgb_key: color_image_processed} + if depth_key is not None and depth_map_processed is not None: + images[depth_key] = depth_map_processed + return images + + def _postprocess_image( + self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False + ) -> np.ndarray: + """ + Applies color conversion, dimension validation, and rotation to a raw frame. + + Args: + image (np.ndarray): The raw image frame from ZED camera. + color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, + uses the instance's default `self.color_mode`. + depth_frame (bool): Whether this is a depth frame. + + Returns: + np.ndarray: The processed image frame according to configuration. + + Raises: + ValueError: If the requested `color_mode` is invalid. + RuntimeError: If the raw frame dimensions do not match expectations. + """ + + if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." 
+ ) + + if depth_frame: + h, w = image.shape + else: + h, w, c = image.shape + # If the image has 4 channels, convert it to 3 channels (e.g., BGRA to BGR) + if c == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR) + # Update the channel count after conversion + c = 3 + + if c != 3: + raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).") + + if h != self.capture_height or w != self.capture_width: + raise RuntimeError( + f"{self} frame width={w} or height={h} do not match configured width={self.capture_width} or height={self.capture_height}." + ) + + processed_image = image + + # ZED returns images in BGR format by default, convert if needed + if not depth_frame and self.color_mode == ColorMode.RGB: + processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: + processed_image = cv2.rotate(processed_image, self.rotation) + + return processed_image + + def _read_loop(self): + """ + Internal loop run by the background thread for asynchronous reading. + + On each iteration: + 1. Reads a color frame + 2. Stores result in latest_frame (thread-safe) + 3. Sets new_frame_event to notify listeners + + Stops on DeviceNotConnectedError, logs other errors and continues. + """ + while not self.stop_event.is_set(): + try: + color_image = self.read(timeout_ms=500) + + with self.frame_lock: + self.latest_frame = color_image + self.new_frame_event.set() + + except DeviceNotConnectedError: + break + except Exception as e: + logger.warning(f"Error reading frame in background thread for {self}: {e}") + + def _start_read_thread(self) -> None: + """Starts or restarts the background read thread if it's not running.""" + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=0.1) + if self.stop_event is not None: + self.stop_event.set() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") + self.thread.daemon = True + self.thread.start() + + def _stop_read_thread(self): + """Signals the background read thread to stop and waits for it to join.""" + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + def async_read(self, timeout_ms: float = 200) -> dict[str, np.ndarray]: + """ + Reads the latest available frame data (color) asynchronously. + + This method retrieves the most recent color frame captured by the background + read thread. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms. + + Returns: + dict[str, np.ndarray]: A map of the latest captured frame data (color and depth image). + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame data becomes available within the specified timeout. + RuntimeError: If the background thread died unexpectedly. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + self._start_read_thread() + + if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): + thread_alive = self.thread is not None and self.thread.is_alive() + raise TimeoutError( + f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " + f"Read thread alive: {thread_alive}." 
+ ) + + with self.frame_lock: + frame = self.latest_frame + self.new_frame_event.clear() + + if frame is None: + raise RuntimeError(f"Internal error: Event set but no frame available for {self}.") + + return frame + + def disconnect(self): + """ + Disconnects from the camera and cleans up resources. + + Stops the background read thread and closes the ZED camera. + + Raises: + DeviceNotConnectedError: If the camera is already disconnected. + """ + + if not self.is_connected and self.thread is None: + raise DeviceNotConnectedError( + f"Attempted to disconnect {self}, but it appears already disconnected." + ) + + if self.thread is not None: + self._stop_read_thread() + + if self.zed_camera is not None: + self.zed_camera.close() + self.zed_camera = None + self.runtime_params = None + self.mat_resolution = None + + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zed/configuration_zed.py b/src/lerobot/cameras/zed/configuration_zed.py new file mode 100644 index 0000000000..5342a29889 --- /dev/null +++ b/src/lerobot/cameras/zed/configuration_zed.py @@ -0,0 +1,78 @@ +import pyzed.sl as sl + +from dataclasses import dataclass + +from ..configs import CameraConfig, ColorMode, Cv2Rotation + + +@CameraConfig.register_subclass("zed") +@dataclass +class ZedCameraConfig(CameraConfig): + """Configuration class for ZED cameras. + + This class provides specialized configuration options for ZED cameras, + including support for depth sensing and device identification via serial number or name. + + Example configurations for ZED 2i: + ```python + # Basic configurations + ZedCameraConfig("0123456789", 30, 1280, 720) # 1280x720 @ 30FPS + ZedCameraConfig("0123456789", 15, 2208, 1242) # 2208x1242 @ 15FPS + + # Advanced configurations + ZedCameraConfig("0123456789", 30, 1280, 720, use_depth=True) # With depth sensing + ZedCameraConfig("0123456789", 30, 1280, 720, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation + ``` + + Attributes: + fps: Requested frames per second for the color stream. + width: Requested frame width in pixels for the color stream. + height: Requested frame height in pixels for the color stream. + serial_number_or_name: Unique serial number or human-readable name to identify the camera. + color_mode: Color mode for image output (RGB or BGR). Defaults to RGB. + use_depth: Whether to enable depth stream. Defaults to False. + rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation. + warmup_s: Time reading frames before returning from connect (in seconds) + depth_mode: Depth sensing mode for ZED camera. Options: 'QUALITY', 'ULTRA', 'NEURAL' + + Note: + - Either name or serial_number must be specified. + - Depth stream configuration (if enabled) will use the same FPS as the color stream. + - The actual resolution and FPS may be adjusted by the camera to the nearest supported mode. + - For `fps`, `width` and `height`, either all of them need to be set, or none of them. + """ + + serial_number_or_name: str = "" # Default to the unique ZED camera + color_mode: ColorMode = ColorMode.RGB + use_depth: bool = True + rotation: Cv2Rotation = Cv2Rotation.ROTATE_180 + warmup_s: int = 3 # ZED cameras need longer warmup time + depth_mode: str = "QUALITY" + camera_view: sl.VIEW = sl.VIEW.LEFT + + def __post_init__(self): + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." 
+ ) + + if self.rotation not in ( + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ): + raise ValueError( + f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." + ) + + if self.depth_mode not in ("QUALITY", "ULTRA", "NEURAL"): + raise ValueError( + f"`depth_mode` is expected to be 'QUALITY', 'ULTRA', or 'NEURAL', but {self.depth_mode} is provided." + ) + + values = (self.fps, self.width, self.height) + if any(v is not None for v in values) and any(v is None for v in values): + raise ValueError( + "For `fps`, `width` and `height`, either all of them need to be set, or none of them." + ) diff --git a/src/lerobot/configs/demo_configs/lerobot_example_config.json b/src/lerobot/configs/demo_configs/lerobot_example_config.json new file mode 100644 index 0000000000..b25815c6fd --- /dev/null +++ b/src/lerobot/configs/demo_configs/lerobot_example_config.json @@ -0,0 +1,133 @@ +{ + "type": "gym_manipulator", + "robot": { + "type": "so100_follower_end_effector", + "port": "/dev/tty.usbmodem58760431631", + "urdf_path": "path/to/your/robot.urdf", + "target_frame_name": "gripper_frame_link", + "cameras": { + "front": { + "type": "opencv", + "index_or_path": 0, + "height": 720, + "width": 1280, + "fps": 30 + }, + "wrist": { + "type": "opencv", + "index_or_path": 1, + "height": 720, + "width": 1280, + "fps": 30 + } + }, + "end_effector_bounds": { + "min": [ + -1.0, + -1.0, + -1.0 + ], + "max": [ + 1.0, + 1.0, + 1.0 + ] + }, + "end_effector_step_sizes": { + "x": 0.025, + "y": 0.025, + "z": 0.025 + } + }, + "teleop": { + "type": "gamepad", + "use_gripper": true + }, + "wrapper": { + "display_cameras": false, + "add_joint_velocity_to_observation": true, + "add_current_to_observation": true, + "add_ee_pose_to_observation": true, + "crop_params_dict": { + "observation.images.front": [ + 270, + 170, + 90, + 190 + ], + "observation.images.wrist": [ + 0, + 0, + 480, + 640 + ] + }, + "resize_size": [ + 128, + 128 + ], + "control_time_s": 20.0, + "use_gripper": true, + "gripper_quantization_threshold": null, + "gripper_penalty": -0.02, + "gripper_penalty_in_reward": false, + "fixed_reset_joint_positions": [ + 0.0, + 0.0, + 0.0, + 90.0, + 0.0, + 5.0 + ], + "reset_time_s": 2.5, + "control_mode": "gamepad" + }, + "name": "real_robot", + "mode": null, + "repo_id": null, + "dataset_root": null, + "task": "", + "num_episodes": 2, + "episode": 0, + "pretrained_policy_name_or_path": null, + "device": "cpu", + "push_to_hub": true, + "fps": 10, + "features": { + "observation.images.front": { + "type": "VISUAL", + "shape": [ + 3, + 128, + 128 + ] + }, + "observation.images.wrist": { + "type": "VISUAL", + "shape": [ + 3, + 128, + 128 + ] + }, + "observation.state": { + "type": "STATE", + "shape": [ + 15 + ] + }, + "action": { + "type": "ACTION", + "shape": [ + 3 + ] + } + }, + "features_map": { + "observation.images.front": "observation.images.side", + "observation.images.wrist": "observation.images.wrist", + "observation.state": "observation.state", + "action": "action" + }, + "reward_classifier_pretrained_path": null +} \ No newline at end of file diff --git a/src/lerobot/configs/demo_configs/serl_config.json b/src/lerobot/configs/demo_configs/serl_config.json new file mode 100644 index 0000000000..e772e784a2 --- /dev/null +++ b/src/lerobot/configs/demo_configs/serl_config.json @@ -0,0 +1,99 @@ +{ + "env": { + "name": "real_robot", + "fps": 30, + "robot": { 
+ "type": "so101_follower", + "port": "/dev/ttyACM1", + "id": "follower_arm", + "use_degrees": true, + "cameras": { + "wrist": { + "type": "opencv", + "index_or_path": 0, + "width": 640, + "height": 480, + "fps": 30 + }, + "head": { + "type": "zed", + "rotation": "ROTATE_180", + "width": 1280, + "height": 720, + "fps": 30 + } + }, + "urdf_path": "./src/lerobot/robots/so100_follower/SO101/so101_new_calib.urdf", + "target_frame_name": "gripper_frame_link" + }, + "teleop": { + "type": "so101_leader", + "port": "/dev/ttyACM0", + "id": "leader_arm", + "use_degrees": true + }, + "processor": { + "control_mode": "leader", + "observation": { + "display_cameras": true, + "add_joint_velocity_to_observation": true, + "add_current_to_observation": true + }, + "image_preprocessing": { + "crop_params_dict": { + "observation.images.wrist": [30, 100, 400, 380], + "observation.images.head_rgb": [100, 100, 580, 580], + "observation.images.head_depth": [100, 100, 580, 580] + }, + "resize_size": [128, 128] + }, + "gripper": { + "use_gripper": true, + "gripper_penalty": 0 + }, + "reset": { + "fixed_reset_joint_positions": [ + 3.4840621948242188, + -99.40602111816406, + 19.78221321105957, + 99.30374145507812, + -53.28620147705078, + 1.26262629032135 + ], + "reset_time_s": 15, + "control_time_s": 60.0, + "terminate_on_success": true + }, + "inverse_kinematics": { + "urdf_path": "./src/lerobot/robots/so101_follower/sim/so101_new_calib.urdf", + "target_frame_name": "gripper_frame_link", + "end_effector_bounds": { + "min": [0.0704, -0.1517, 0.0053], + "max": [0.4292, 0.3503, 0.4122] + }, + "end_effector_step_sizes": { + "x": 0.02, + "y": 0.02, + "z": 0.02 + } + }, + "reward_classifier": { + "pretrained_path": null, + "success_threshold": 0.5, + "success_reward": 1.0 + }, + "max_gripper_pos": 100 + } + }, + "dataset": { + "repo_id": "weiwenliu/lerobot-serl-test", + "root": null, + "task": "Test SERL in cardbox perturbation", + "num_episodes_to_record": 2, + "replay_episode": 0, + "push_to_hub": true + }, + "mode": "record", + "device": "cuda", + "resume": false +} \ No newline at end of file diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index ee10df6e19..27df9139da 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -38,34 +38,98 @@ def wrapper(*args, **kwargs): return wrapper -def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: - # TODO(aliberts): handle 1 channel and 4 for depth images - if image_array.ndim != 3: - raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.") - - if image_array.shape[0] == 3: - # Transpose from pytorch convention (C, H, W) to (H, W, C) - image_array = image_array.transpose(1, 2, 0) - - elif image_array.shape[-1] != 3: - raise NotImplementedError( - f"The image has {image_array.shape[-1]} channels, but 3 is required for now." - ) - - if image_array.dtype != np.uint8: - if range_check: - max_ = image_array.max().item() - min_ = image_array.min().item() - if max_ > 1.0 or min_ < 0.0: - raise ValueError( - "The image data type is float, which requires values in the range [0.0, 1.0]. " - f"However, the provided range is [{min_}, {max_}]. Please adjust the range or " - "provide a uint8 image with values in the range [0, 255]." - ) +def image_array_to_pil_image( + image_array: np.ndarray, range_check: bool = True +) -> PIL.Image.Image: + """ + Convert numpy array to PIL Image with support for both 2D and 3D arrays. 
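+    2D float arrays with values outside [0.0, 1.0] (e.g. metric depth) are min-max
+    normalized to uint8 (when `range_check` is enabled), 2D uint16 arrays are kept as
+    16-bit ("I;16") images, and 3D arrays are transposed to (H, W, C) before conversion.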
- image_array = (image_array * 255).astype(np.uint8) + Args: + image_array: Input image array, can be 2D (H, W) or 3D (C, H, W) or (H, W, C) + range_check: Whether to validate value ranges for float arrays - return PIL.Image.fromarray(image_array) + Returns: + PIL.Image.Image: Converted PIL image + """ + # Handle 2D arrays (depth maps, grayscale images) + if image_array.ndim == 2: + if image_array.dtype in [np.float32, np.float64]: + if range_check: + max_ = image_array.max().item() + min_ = image_array.min().item() + if max_ > 1.0 or min_ < 0.0: + # For depth maps, normalize to 0-255 range + image_array = (image_array - min_) / (max_ - min_) * 255 + image_array = image_array.astype(np.uint8) + else: + image_array = (image_array * 255).astype(np.uint8) + else: + image_array = (image_array * 255).astype(np.uint8) + elif image_array.dtype == np.uint16: + # Depth maps in uint16 - keep as is or scale down to uint8 + # Option 1: Keep as uint16 (requires mode "I;16") + return PIL.Image.fromarray(image_array, mode="I;16") + else: + # Other 2D arrays (uint8, etc.) + image_array = image_array.astype(np.uint8) + + return PIL.Image.fromarray(image_array, mode="L") + + # Handle 3D arrays to handle both (C, H, W) and (H, W, C) + elif image_array.ndim == 3: + # Determine the channel dimension + channels_dim = None + spatial_dims = [] + + for i, dim in enumerate(image_array.shape): + if dim in [1, 3, 4]: # Possible channel counts + if channels_dim is None: + channels_dim = i + else: + spatial_dims.append(dim) + else: + spatial_dims.append(dim) + + # If we found a channel dimension and it's not the last dimension, transpose + if channels_dim is not None and channels_dim != 2: + # Transpose from (C, H, W) to (H, W, C) + if channels_dim == 0: + image_array = image_array.transpose(1, 2, 0) + elif channels_dim == 1: + image_array = image_array.transpose(0, 2, 1) + + # Now image_array should be in (H, W, C) format + channels = image_array.shape[2] + + # Handle different channel counts + if channels == 1: + # Single channel image (depth, grayscale) + image_array = image_array[:, :, 0] # Convert to 2D + return image_array_to_pil_image(image_array, range_check) + elif channels == 3: + # RGB image + if image_array.dtype != np.uint8: + if range_check: + max_ = image_array.max().item() + min_ = image_array.min().item() + if max_ > 1.0 or min_ < 0.0: + raise ValueError( + "The image data type is float, which requires values in the range [0.0, 1.0]. " + f"However, the provided range is [{min_}, {max_}]. Please adjust the range or " + "provide a uint8 image with values in the range [0, 255]." + ) + image_array = (image_array * 255).astype(np.uint8) + return PIL.Image.fromarray(image_array) + else: + raise NotImplementedError( + f"The image has {channels} channels, but 1 or 3 is required. " + f"Shape: {image_array.shape}" + ) + + else: + raise ValueError( + f"Unsupported array dimensions: {image_array.ndim}. Expected 2D or 3D array." + ) def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1): diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b661b21b03..bb5048990c 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -15,6 +15,7 @@ # limitations under the License. 
import contextlib import gc +import json import logging import shutil import tempfile @@ -184,7 +185,21 @@ def image_keys(self) -> list[str]: @property def video_keys(self) -> list[str]: """Keys to access visual modalities stored as videos.""" - return [key for key, ft in self.features.items() if ft["dtype"] == "video"] + video_keys_list = [] + + for key, ft in self.features.items(): + # Skip features that don't have the expected dictionary structure + if not isinstance(ft, dict): + continue + + # Skip features without dtype field + if "dtype" not in ft: + continue + + if ft["dtype"] == "video": + video_keys_list.append(key) + + return video_keys_list @property def camera_keys(self) -> list[str]: diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index a2f2850141..da5c6ddeb9 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -653,6 +653,40 @@ def hw_to_dataset_features( return features +def _find_image_keys(feature_key: str, values: dict[str, Any]) -> list[str]: + """Find all matching image keys in values for a given feature key. + + Args: + feature_key: Camera name from dataset feature (e.g., "head") + values: Dictionary containing image data + + Returns: + list[str]: List of matching keys from values dictionary + + Raises: + KeyError: If no matching image key found + """ + # Direct camera name match (backward compatibility) + if feature_key in values: + return [feature_key] + + # Look for valid camera_modality format keys + valid_modalities = ["rgb", "gray", "depth", "rgba", "ir"] + valid_keys = [ + k + for k in values.keys() + if any(k == f"{feature_key}_{mod}" for mod in valid_modalities) + ] + + if not valid_keys: + available_keys = [k for k in values.keys() if "_" in k] + raise KeyError( + f"No valid image key found for '{feature_key}'. Available: {available_keys}" + ) + + return valid_keys + + def build_dataset_frame( ds_features: dict[str, dict], values: dict[str, Any], prefix: str ) -> dict[str, np.ndarray]: @@ -677,7 +711,13 @@ def build_dataset_frame( elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) elif ft["dtype"] in ["image", "video"]: - frame[key] = values[key.removeprefix(f"{prefix}.images.")] + feature_key = key.removeprefix(f"{prefix}.images.") + if feature_key in values: + frame[key] = values[feature_key] + else: + matched_keys = _find_image_keys(feature_key, values) + for matched_key in matched_keys: + frame[f"{prefix}.images.{matched_key}"] = values[matched_key] return frame @@ -1083,7 +1123,7 @@ def validate_feature_image_or_video( Args: name (str): The name of the feature. - expected_shape (list[str]): The expected shape (C, H, W). + expected_shape (list[str]): The expected shape (C, H, W), (H, W). value: The image data to validate. Returns: @@ -1093,9 +1133,23 @@ def validate_feature_image_or_video( error_message = "" if isinstance(value, np.ndarray): actual_shape = value.shape - c, h, w = expected_shape - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + # Depth or gray image. 
+ if len(expected_shape) == 2: + if len(actual_shape) != 2: + error_message += f"Feature '{name}': expected 2D array, got {len(actual_shape)}D shape {actual_shape}\n" + elif actual_shape != tuple(expected_shape): + error_message += f"Feature '{name}': shape {actual_shape} != expected {tuple(expected_shape)}\n" + + # RGB image. + elif len(expected_shape) == 3: + c, h, w = expected_shape + if len(actual_shape) != 3: + error_message += f"Feature '{name}': expected 3D array, got {len(actual_shape)}D shape {actual_shape}\n" + elif actual_shape not in [(c, h, w), (h, w, c)]: + error_message += f"Feature '{name}': shape {actual_shape} not in expected {[(c, h, w), (h, w, c)]}\n" + + else: + error_message += f"Feature '{name}': invalid expected shape {expected_shape}\n" elif isinstance(value, PILImage.Image): pass else: diff --git a/src/lerobot/model/kinematics.py b/src/lerobot/model/kinematics.py index f059b97907..c8661ae43a 100644 --- a/src/lerobot/model/kinematics.py +++ b/src/lerobot/model/kinematics.py @@ -32,6 +32,11 @@ def __init__( target_frame_name: Name of the end-effector frame in the URDF joint_names: List of joint names to use for the kinematics solver """ + print(f"DEBUG: Kinematics solver initialized:") + print(f" URDF path: {urdf_path}") + print(f" Target frame: {target_frame_name}") + print(f" Motor names: {joint_names}") + try: import placo except ImportError as e: diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index be11ac1af4..3eef2cd740 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -28,7 +28,12 @@ RobotObservation, TransitionKey, ) -from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep +from .delta_action_processor import ( + MapDeltaActionToRobotActionStep, + MapTensorToDeltaActionDictStep, + Map7DDeltaActionToRobotActionStep, + MapTensorTo7DDeltaActionDictStep, +) from .device_processor import DeviceProcessorStep from .factory import ( make_default_processors, @@ -46,9 +51,11 @@ GripperPenaltyProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, + LeaderArmInterventionProcessorStep, RewardClassifierProcessorStep, TimeLimitProcessorStep, ) +from .joint_action_processor import DirectJointControlStep, JointBoundsAndSafetyStep from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats from .observation_processor import VanillaObservationProcessorStep @@ -71,6 +78,7 @@ TruncatedProcessorStep, ) from .policy_robot_bridge import ( + DirectJointToPolicyActionProcessorStep, PolicyActionToRobotActionProcessorStep, RobotActionToPolicyActionProcessorStep, ) @@ -85,6 +93,8 @@ "batch_to_transition", "create_transition", "DeviceProcessorStep", + "DirectJointControlStep", + "DirectJointToPolicyActionProcessorStep", "DoneProcessorStep", "EnvAction", "EnvTransition", @@ -95,12 +105,16 @@ "InfoProcessorStep", "InterventionActionProcessorStep", "JointVelocityProcessorStep", + "JointBoundsAndSafetyStep", + "LeaderArmInterventionProcessorStep", "make_default_processors", "make_default_teleop_action_processor", "make_default_robot_action_processor", "make_default_robot_observation_processor", + "Map7DDeltaActionToRobotActionStep", "MapDeltaActionToRobotActionStep", "MapTensorToDeltaActionDictStep", + "MapTensorTo7DDeltaActionDictStep", "MotorCurrentProcessorStep", "NormalizerProcessorStep", 
"Numpy2TorchActionProcessorStep", diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index a8395637ca..a7182308e4 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -141,3 +141,125 @@ def transform_features( ) return features + + +@ProcessorStepRegistry.register("map_tensor_to_7d_delta_action_dict") +@dataclass +class MapTensorTo7DDeltaActionDictStep(ActionProcessorStep): + """ + Maps a flat 7D action tensor to a structured delta action dictionary. + Supports 7-dimensional actions: [x, y, z, rx, ry, rz, gripper] + """ + + use_gripper: bool = True + + def action(self, action: PolicyAction) -> RobotAction: + if not isinstance(action, PolicyAction): + raise ValueError("Only PolicyAction is supported for this processor") + + if action.dim() > 1: + action = action.squeeze(0) + + delta_action = { + "delta_x": action[0].item(), + "delta_y": action[1].item(), + "delta_z": action[2].item(), + "delta_rx": action[3].item(), + "delta_ry": action[4].item(), + "delta_rz": action[5].item(), + } + if self.use_gripper: + delta_action["gripper"] = action[6].item() + return delta_action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for axis in ["x", "y", "z"]: + features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + for axis in ["rx", "ry", "rz"]: + features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + if self.use_gripper: + features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + return features + + +@ProcessorStepRegistry.register("map_7d_delta_action_to_robot_action") +@dataclass +class Map7DDeltaActionToRobotActionStep(RobotActionProcessorStep): + """ + Maps 7D delta actions to robot target actions for inverse kinematics. + Supports both position and rotation deltas. 
+ """ + + # Scale factors for delta movements + position_scale: float = 1.0 + rotation_scale: float = 1.0 + noise_threshold: float = 1e-3 + + def action(self, action: RobotAction) -> RobotAction: + delta_x = action.pop("delta_x") + delta_y = action.pop("delta_y") + delta_z = action.pop("delta_z") + delta_rx = action.pop("delta_rx") + delta_ry = action.pop("delta_ry") + delta_rz = action.pop("delta_rz") + gripper = action.pop("gripper") + + position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 + rotation_magnitude = (delta_rx**2 + delta_ry**2 + delta_rz**2) ** 0.5 + enabled = (position_magnitude > self.noise_threshold) or ( + rotation_magnitude > self.noise_threshold + ) + + scaled_delta_x = delta_x * self.position_scale + scaled_delta_y = delta_y * self.position_scale + scaled_delta_z = delta_z * self.position_scale + + scaled_delta_rx = delta_rx * self.rotation_scale + scaled_delta_ry = delta_ry * self.rotation_scale + scaled_delta_rz = delta_rz * self.rotation_scale + + action = { + "enabled": enabled, + "target_x": scaled_delta_x, + "target_y": scaled_delta_y, + "target_z": scaled_delta_z, + "target_wx": scaled_delta_rx, + "target_wy": scaled_delta_ry, + "target_wz": scaled_delta_rz, + "gripper_vel": float(gripper), + } + + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for axis in ["x", "y", "z", "rx", "ry", "rz", "gripper"]: + features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None) + + for feat in [ + "enabled", + "target_x", + "target_y", + "target_z", + "target_wx", + "target_wy", + "target_wz", + "gripper_vel", + ]: + features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index f0dbac9c3c..6875742e03 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -23,10 +23,12 @@ import numpy as np import torch import torchvision.transforms.functional as F # noqa: N812 +from lerobot.model.kinematics import RobotKinematics from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents +from lerobot.utils.rotation import Rotation from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ( @@ -141,6 +143,8 @@ class AddTeleopEventsAsInfoStep(InfoProcessorStep): """ teleop_device: TeleopWithEvents + _debug_frame_count: int = 0 + _last_space_state: bool = False def __post_init__(self): """Validates that the provided teleoperator supports events after initialization.""" @@ -156,10 +160,54 @@ def info(self, info: dict) -> dict: Returns: A new dictionary including the teleoperator events. """ - new_info = dict(info) + self._debug_frame_count += 1 teleop_events = self.teleop_device.get_teleop_events() + + if self._debug_frame_count % 30 == 30: # Disable now + print( + f"\n=== DEEP DEBUG TELEOP EVENTS (Frame {self._debug_frame_count}) ===" + ) + print(f"1. Raw teleop_events: {teleop_events}") + print("2. Teleop device details:") + print(f" Type: {type(self.teleop_device)}") + print(f" Module: {self.teleop_device.__class__.__module__}") + + print("3. 
Space key detection (multiple methods):") + + if hasattr(self.teleop_device, "is_space_pressed"): + space_pressed = self.teleop_device.is_space_pressed() + print(f" is_space_pressed(): {space_pressed}") + else: + print(" is_space_pressed(): Method not available") + + if hasattr(self.teleop_device, "is_intervention_triggered"): + intervention_triggered = self.teleop_device.is_intervention_triggered() + print(f" is_intervention_triggered(): {intervention_triggered}") + else: + print(" is_intervention_triggered(): Method not available") + + if hasattr(self.teleop_device, "get_state"): + state = self.teleop_device.get_state() + print(f" get_state(): {state}") + else: + print(" get_state(): Method not available") + + if hasattr(self.teleop_device, "buttons"): + print(f" buttons: {self.teleop_device.buttons}") + else: + print(" buttons: Attribute not available") + + if hasattr(self.teleop_device, "key_events"): + print(f" key_events: {self.teleop_device.key_events}") + else: + print(" key_events: Attribute not available") + + print("=== END DEEP DEBUG ===\n") + + new_info = dict(info) new_info.update(teleop_events) + return new_info def transform_features( @@ -408,6 +456,7 @@ class InterventionActionProcessorStep(ProcessorStep): use_gripper: bool = False terminate_on_success: bool = True + _debug_frame_count: int = 0 def __call__(self, transition: EnvTransition) -> EnvTransition: """ @@ -420,6 +469,7 @@ def __call__(self, transition: EnvTransition) -> EnvTransition: The modified transition, potentially with an overridden action, updated reward, and termination status. """ + self._debug_frame_count += 1 action = transition.get(TransitionKey.ACTION) if not isinstance(action, PolicyAction): raise ValueError(f"Action should be a PolicyAction type got {type(action)}") @@ -433,6 +483,22 @@ def __call__(self, transition: EnvTransition) -> EnvTransition: success = info.get(TeleopEvents.SUCCESS, False) rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False) + if self._debug_frame_count % 30 == 1: # fps=30 + print( + f"\n=== DEBUG INTERVENTION PROCESSOR (Frame {self._debug_frame_count}) ===" + ) + print( + f"Input action type: {type(action)}, shape: {getattr(action, 'shape', 'No shape')}" + ) + print(f"Info keys: {list(info.keys())}") + print(f"Complementary data keys: {list(complementary_data.keys())}") + print(f"Teleop action type: {type(teleop_action)}, value: {teleop_action}") + print("Intervention signals:") + print(f" IS_INTERVENTION: {is_intervention}") + print(f" TERMINATE_EPISODE: {terminate_episode}") + print(f" SUCCESS: {success}") + print(f" RERECORD_EPISODE: {rerecord_episode}") + new_transition = transition.copy() # Override action if intervention is active @@ -446,20 +512,46 @@ def __call__(self, transition: EnvTransition) -> EnvTransition: ] if self.use_gripper: action_list.append(teleop_action.get(GRIPPER_KEY, 1.0)) + + if self._debug_frame_count % 30 == 1: + print(f"Converting teleop dict to list: {action_list}") + elif isinstance(teleop_action, np.ndarray): action_list = teleop_action.tolist() + if self._debug_frame_count % 30 == 1: + print(f"Converting teleop numpy array to list: {action_list}") else: action_list = teleop_action + if self._debug_frame_count % 30 == 1: + print(f"Using teleop action as-is: {action_list}") + + teleop_action_tensor = torch.tensor( + action_list, dtype=action.dtype, device=action.device + ) + + if self._debug_frame_count % 30 == 1: + print("ACTION OVERRIDE:") + print(f" Original policy action: {action}") + print(f" New teleop action: 
{teleop_action_tensor}")
+                print(f"  Action device: {teleop_action_tensor.device}")
+                print(f"  Action dtype: {teleop_action_tensor.dtype}")
 
-            teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
             new_transition[TransitionKey.ACTION] = teleop_action_tensor
 
         # Handle episode termination
-        new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
-            self.terminate_on_success and success
-        )
+        original_done = transition.get(TransitionKey.DONE, False)
+        new_done = bool(terminate_episode) or (self.terminate_on_success and success)
+        new_transition[TransitionKey.DONE] = new_done
         new_transition[TransitionKey.REWARD] = float(success)
 
+        if self._debug_frame_count % 30 == 1 and (original_done != new_done or success):
+            print("TERMINATION STATUS:")
+            print(f"  Original done: {original_done}")
+            print(
+                f"  New done: {new_done} (terminate_episode: {terminate_episode}, success: {success})"
+            )
+            print(f"  Reward set to: {float(success)}")
+
         # Update info with intervention metadata
         info = new_transition.get(TransitionKey.INFO, {})
         info[TeleopEvents.IS_INTERVENTION] = is_intervention
@@ -472,6 +564,17 @@ def __call__(self, transition: EnvTransition) -> EnvTransition:
         complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
         new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
 
+        # Final state check
+        if self._debug_frame_count % 30 == 1:
+            final_action = new_transition.get(TransitionKey.ACTION)
+            print("FINAL TRANSITION STATE:")
+            print(f"  Final action: {final_action}")
+            print(f"  Final action type: {type(final_action)}")
+            print(f"  Final done: {new_transition.get(TransitionKey.DONE)}")
+            print(f"  Final reward: {new_transition.get(TransitionKey.REWARD)}")
+            print(f"  Info[IS_INTERVENTION]: {info.get(TeleopEvents.IS_INTERVENTION)}")
+            print("=== END DEBUG ===\n")
+
         return new_transition
 
     def get_config(self) -> dict[str, Any]:
@@ -492,6 +595,524 @@ def transform_features(
         return features
+
+
+@ProcessorStepRegistry.register("leader_arm_intervention")
+class LeaderArmInterventionProcessorStep(ProcessorStep):
+    """
+    Leader arm intervention with direct joint position control.
+    The user moves the leader arm and the follower copies its joint positions directly.
+    A delta action is still computed and saved for the dataset.
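+
+    A minimal usage sketch (illustrative only; `solver` stands for an already
+    constructed RobotKinematics instance and the motor list mirrors the SO-101
+    follower; see `__init__` below for the full signature):
+
+        step = LeaderArmInterventionProcessorStep(
+            use_gripper=True,
+            terminate_on_success=True,
+            kinematics_solver=solver,
+            motor_names=[
+                "shoulder_pan", "shoulder_lift", "elbow_flex",
+                "wrist_flex", "wrist_roll", "gripper",
+            ],
+        )
+        new_transition = step(transition)  # EnvTransition in, EnvTransition out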
+ """ + + use_gripper: bool = False + terminate_on_success: bool = True + + def __init__( + self, + use_gripper: bool = False, + terminate_on_success: bool = True, + sync_tolerances: dict | None = None, + kinematics_solver: RobotKinematics | None = None, + motor_names: list[str] | None = None, + ): + self.use_gripper = use_gripper + self.terminate_on_success = terminate_on_success + self.kinematics_solver = kinematics_solver + self.motor_names = motor_names or [] + + # Position tolerance in degrees + self.sync_tolerances = sync_tolerances or { + "shoulder_pan.pos": 5.0, + "shoulder_lift.pos": 10.0, + "elbow_flex.pos": 10.0, + "wrist_flex.pos": 10.0, + "wrist_roll.pos": 10.0, + "gripper.pos": 10.0, + } + + # Movement detection thresholds + self.position_threshold = 0.002 + self.orientation_threshold = 0.01 + self.gripper_threshold = 1.0 + + self._debug_frame_count = 0 + self._last_intervention_state = False + self._is_position_synced = False + self._follower_reference_positions = None + self._leader_base_positions = None + self._leader_base_ee_pose = None + self._sync_start_time = None + self._last_leader_ee_pose = None + self._stable_frames_count = 0 + self._stable_frames_required = 10 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + self._debug_frame_count += 1 + + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + info = transition.get(TransitionKey.INFO, {}) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {}) + is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False) + terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False) + success = info.get(TeleopEvents.SUCCESS, False) + rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False) + + new_transition = transition.copy() + + # Detect intervention state transitions + intervention_started = is_intervention and not self._last_intervention_state + intervention_ended = not is_intervention and self._last_intervention_state + + if intervention_started: + print( + "LEADER ARM INTERVENTION: Intervention started - initializing manual position synchronization" + ) + self._initialize_manual_sync(transition) + + if intervention_ended: + print("LEADER ARM INTERVENTION: Intervention ended") + self._reset_sync_state() + + self._last_intervention_state = is_intervention + + # Process teleoperation during intervention + if is_intervention and teleop_action is not None: + if isinstance(teleop_action, dict) and any( + ".pos" in key for key in teleop_action.keys() + ): + if not self._is_position_synced: + sync_complete = self._check_manual_sync_complete( + teleop_action, transition + ) + if sync_complete: + print( + "LEADER ARM INTERVENTION: Position synchronization COMPLETE!" 
+ ) + print("LEADER ARM INTERVENTION: Ready for teleoperation") + self._is_position_synced = True + self._leader_base_positions = teleop_action.copy() + if self.kinematics_solver is not None: + self._leader_base_ee_pose = ( + self._compute_ee_pose_from_joints(teleop_action) + ) + self._last_leader_ee_pose = self._leader_base_ee_pose + + # Store leader joint positions for direct control + leader_joint_positions = self._extract_leader_joint_positions( + teleop_action + ) + complementary_data["leader_joint_positions"] = ( + leader_joint_positions + ) + + # Compute and store delta action for dataset + delta_action = self._compute_delta_action(teleop_action) + complementary_data["teleop_action"] = torch.tensor( + delta_action, dtype=action.dtype, device=action.device + ) + + # Send current positions to maintain pose during sync transition + observation = transition.get(TransitionKey.OBSERVATION, {}) + current_joint_positions = [] + for motor_name in self.motor_names: + joint_key = f"{motor_name}.pos" + if joint_key in observation: + current_joint_positions.append(observation[joint_key]) + else: + current_joint_positions.append(0.0) + + current_positions_tensor = torch.tensor( + current_joint_positions, + dtype=action.dtype, + device=action.device, + ) + new_transition[TransitionKey.ACTION] = current_positions_tensor + + else: + # Still syncing: maintain current position but compute delta for recording + observation = transition.get(TransitionKey.OBSERVATION, {}) + + # Compute delta action for dataset (even though we're not moving) + delta_action = self._compute_delta_action(teleop_action) + complementary_data["teleop_action"] = torch.tensor( + delta_action, dtype=action.dtype, device=action.device + ) + + # Send current positions to maintain pose + current_joint_positions = [] + for motor_name in self.motor_names: + joint_key = f"{motor_name}.pos" + if joint_key in observation: + current_joint_positions.append(observation[joint_key]) + else: + current_joint_positions.append(0.0) + + current_positions_tensor = torch.tensor( + current_joint_positions, + dtype=action.dtype, + device=action.device, + ) + new_transition[TransitionKey.ACTION] = current_positions_tensor + else: + # Position sync complete - compute EE delta action first + delta_action = self._compute_delta_action(teleop_action) + + # Store delta action for dataset (always) + complementary_data["teleop_action"] = torch.tensor( + delta_action, dtype=action.dtype, device=action.device + ) + + # Check EE delta magnitude to decide if we should move follower + delta_magnitude = np.linalg.norm(delta_action[:3]) + rot_magnitude = np.linalg.norm(delta_action[3:6]) + + if delta_magnitude > 0.005 or rot_magnitude > 0.01: + # Movement detected: store leader joint positions for direct control + leader_joint_positions = self._extract_leader_joint_positions(teleop_action) + complementary_data["leader_joint_positions"] = leader_joint_positions + print(f"Movement detected: pos={delta_magnitude:.4f}, rot={rot_magnitude:.4f}") + else: + # No movement: clear leader joint positions + complementary_data["leader_joint_positions"] = None + + # Send current joint positions to maintain pose (not EE delta action) + observation = transition.get(TransitionKey.OBSERVATION, {}) + current_joint_positions = [] + for motor_name in self.motor_names: + joint_key = f"{motor_name}.pos" + if joint_key in observation: + current_joint_positions.append(observation[joint_key]) + else: + current_joint_positions.append(0.0) + + current_positions_tensor = torch.tensor( + 
current_joint_positions, dtype=action.dtype, device=action.device + ) + new_transition[TransitionKey.ACTION] = current_positions_tensor + + # Update base for delta calculation + self._leader_base_positions = teleop_action.copy() + if self.kinematics_solver is not None: + self._leader_base_ee_pose = self._compute_ee_pose_from_joints( + teleop_action + ) + else: + # No intervention: maintain current position + observation = transition.get(TransitionKey.OBSERVATION, {}) + + # Get current joint positions from observation + current_joint_positions = [] + for motor_name in self.motor_names: + joint_key = f"{motor_name}.pos" + if joint_key in observation: + current_joint_positions.append(observation[joint_key]) + else: + current_joint_positions.append(0.0) + + # Send current positions to maintain pose + current_positions_tensor = torch.tensor( + current_joint_positions, dtype=action.dtype, device=action.device + ) + new_transition[TransitionKey.ACTION] = current_positions_tensor + + print(f"DEBUG: Maintaining current position: {[f'{x:.1f}' for x in current_joint_positions]}") + + # Handle episode termination + new_transition[TransitionKey.DONE] = bool(terminate_episode) or ( + self.terminate_on_success and success + ) + new_transition[TransitionKey.REWARD] = float(success) + + # Update info dictionary + info = new_transition.get(TransitionKey.INFO, {}) + info[TeleopEvents.IS_INTERVENTION] = is_intervention + info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode + info[TeleopEvents.SUCCESS] = success + new_transition[TransitionKey.INFO] = info + + # Update complementary data + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + + return new_transition + + def _extract_leader_joint_positions(self, leader_positions: dict) -> list: + """Extract leader joint positions in the correct order for motor_names""" + joint_positions = [] + for motor_name in self.motor_names: + joint_key = f"{motor_name}.pos" + if joint_key in leader_positions: + joint_positions.append(leader_positions[joint_key]) + else: + joint_positions.append(0.0) + + return joint_positions + + def _compute_delta_action(self, current_leader_positions: dict) -> list: + """Compute delta action for dataset recording""" + if self._leader_base_ee_pose is None or self.kinematics_solver is None: + return [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + ( + [0.0] if self.use_gripper else [0.0] + ) + + leader_current_pose = self._compute_ee_pose_from_joints( + current_leader_positions + ) + base_pose_inv = np.linalg.inv(self._leader_base_ee_pose) + delta_pose = leader_current_pose @ base_pose_inv + + delta_action = self._pose_to_7d_action(delta_pose) + + if self.use_gripper: + current_gripper = current_leader_positions["gripper.pos"] + base_gripper = self._leader_base_positions["gripper.pos"] + gripper_delta = (current_gripper - base_gripper) * 0.1 + delta_action.append(gripper_delta) + else: + delta_action.append(0.0) + + return delta_action + + def _compute_ee_pose_from_joints(self, joint_positions: dict) -> np.ndarray: + if self.kinematics_solver is None: + return np.eye(4) + + joint_array = [] + for motor_name in self.motor_names: + joint_key = motor_name + ".pos" + if joint_key in joint_positions: + joint_array.append(joint_positions[joint_key]) + else: + joint_array.append(0.0) + + ee_pose = self.kinematics_solver.forward_kinematics(np.array(joint_array)) + return ee_pose + + 
def _pose_to_7d_action(self, pose: np.ndarray) -> list:
+        """Convert a 4x4 transform into a 6D action [x, y, z, rx, ry, rz]; the gripper term is appended by the caller."""
+        position = pose[:3, 3]
+        rotation_matrix = pose[:3, :3]
+
+        try:
+            from scipy.spatial.transform import Rotation
+
+            rot = Rotation.from_matrix(rotation_matrix)
+            euler_angles = rot.as_euler("zyx", degrees=False)
+            rx, ry, rz = euler_angles[2], euler_angles[1], euler_angles[0]
+        except ImportError:
+            rx = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
+            ry = np.arctan2(
+                -rotation_matrix[2, 0],
+                np.sqrt(rotation_matrix[2, 1] ** 2 + rotation_matrix[2, 2] ** 2),
+            )
+            rz = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
+
+        return [position[0], position[1], position[2], rx, ry, rz]
+
+    def _initialize_manual_sync(self, transition: EnvTransition):
+        """Initialize manual position synchronization."""
+        self._follower_reference_positions = self._get_follower_joint_positions(
+            transition
+        )
+        # Start the timeout clock used by _check_manual_sync_complete
+        self._sync_start_time = time.time()
+
+        if self._follower_reference_positions:
+            print("\n" + "=" * 80)
+            print("LEADER ARM INTERVENTION: MANUAL POSITION SYNCHRONIZATION REQUIRED")
+            print("=" * 80)
+            print("Follower arm current positions:")
+            for joint, pos in self._follower_reference_positions.items():
+                tolerance = self.sync_tolerances.get(joint, 10.0)
+                print(f"  {joint}: {pos:7.2f}° (Tolerance: ±{tolerance}°)")
+
+            print("\nINSTRUCTIONS:")
+            print("1. Manually move LEADER arm to match FOLLOWER positions above")
+            print("2. Keep leader arm STEADY when green checkmarks appear")
+            print("3. System will auto-detect when synchronization is complete")
+            print("4. Real-time joint differences will be displayed below")
+            print("=" * 80 + "\n")
+
+    def _reset_sync_state(self):
+        """Reset synchronization state when intervention ends."""
+        self._is_position_synced = False
+        self._follower_reference_positions = None
+        self._leader_base_positions = None
+        self._leader_base_ee_pose = None
+        self._last_leader_ee_pose = None
+        self._stable_frames_count = 0
+        self._sync_start_time = None
+
+    def _get_follower_joint_positions(self, transition: EnvTransition) -> dict:
+        """Extract follower robot joint positions from observation."""
+        observation = transition.get(TransitionKey.OBSERVATION, {})
+        follower_positions = {}
+        for joint_name in [
+            "shoulder_pan.pos",
+            "shoulder_lift.pos",
+            "elbow_flex.pos",
+            "wrist_flex.pos",
+            "wrist_roll.pos",
+            "gripper.pos",
+        ]:
+            if joint_name in observation:
+                follower_positions[joint_name] = observation[joint_name]
+        return follower_positions
+
+    def _check_manual_sync_complete(
+        self, current_leader_positions: dict, transition: EnvTransition
+    ) -> bool:
+        """
+        Check synchronization with enhanced movement detection using EE pose.
+        Display joint position differences for user guidance.
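+
+        A note on the completion criterion (see the checks below): sync is
+        considered complete only when every joint difference is within its
+        configured tolerance AND the leader end-effector has stayed still for
+        `_stable_frames_required` consecutive frames.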
+ """ + if self._follower_reference_positions is None: + return True + + # Check synchronization timeout + if self._sync_start_time and (time.time() - self._sync_start_time) > 120: + print("LEADER ARM INTERVENTION: Synchronization timeout") + return True + + # Calculate current leader EE pose + current_leader_ee_pose = self._compute_ee_pose_from_joints( + current_leader_positions + ) + + # Check if leader arm is stable (not moving significantly) + is_stable = self._check_leader_stability( + current_leader_ee_pose, current_leader_positions + ) + + # Check joint-based synchronization and display differences + joints_synced = self._check_joint_based_sync_with_display(current_leader_positions) + + # Sync complete only when joints are synced AND leader is stable + sync_complete = joints_synced and is_stable + + if self._debug_frame_count % 10 == 1: + print( + f"SYNC STATUS: joints_synced={joints_synced}, stable={is_stable}, stable_frames={self._stable_frames_count}/{self._stable_frames_required}" + ) + + return sync_complete + + def _check_joint_based_sync_with_display(self, current_leader_positions: dict) -> bool: + """Check joint-based synchronization with tolerance and display differences.""" + if self._follower_reference_positions is None: + return True + + all_within_tolerance = True + + # Display header every 30 frames + if self._debug_frame_count % 60 == 1: + print("\n" + "="*80) + print("JOINT SYNCHRONIZATION STATUS") + print("="*80) + print(f"{'Joint':<20} {'Follower':<10} {'Leader':<10} {'Diff':<10} {'Tolerance':<10} {'Status':<10}") + print("-"*80) + + for joint_name, ref_pos in self._follower_reference_positions.items(): + if joint_name in current_leader_positions: + current_pos = current_leader_positions[joint_name] + error = abs(current_pos - ref_pos) + tolerance = self.sync_tolerances.get(joint_name, 10.0) + within_tolerance = error <= tolerance + + if not within_tolerance: + all_within_tolerance = False + + # Display status every frame for user feedback + status = "✓ OK" if within_tolerance else "✗ NEED MOVE" + print(f"{joint_name:<20} {ref_pos:>9.1f}° {current_pos:>9.1f}° {error:>9.1f}° {tolerance:>9.1f}° {status:>10}") + + # Display summary every 30 frames + if self._debug_frame_count % 60 == 1: + print("="*80) + if all_within_tolerance: + print("ALL JOINTS WITHIN TOLERANCE - Keep arm steady to complete sync") + else: + print("ADJUST LEADER ARM to match follower positions above") + print("="*80) + + return all_within_tolerance + + def _check_leader_stability( + self, current_ee_pose: np.ndarray, current_leader_positions: dict + ) -> bool: + """Check if leader arm is stable.""" + if self._last_leader_ee_pose is None: + self._last_leader_ee_pose = current_ee_pose + return False + + current_pos = current_ee_pose[:3, 3] + last_pos = self._last_leader_ee_pose[:3, 3] + pos_change = np.linalg.norm(current_pos - last_pos) + + current_rot = current_ee_pose[:3, :3] + last_rot = self._last_leader_ee_pose[:3, :3] + rot_change = np.linalg.norm(current_rot - last_rot) + + current_gripper = current_leader_positions.get("gripper.pos", 0.0) + last_gripper = ( + self._leader_base_positions.get("gripper.pos", 0.0) + if self._leader_base_positions + else current_gripper + ) + gripper_change = abs(current_gripper - last_gripper) + + is_stable = ( + pos_change < self.position_threshold + and rot_change < self.orientation_threshold + and gripper_change < self.gripper_threshold + ) + + if is_stable: + self._stable_frames_count += 1 + else: + self._stable_frames_count = 0 + + self._last_leader_ee_pose 
= current_ee_pose + + return self._stable_frames_count >= self._stable_frames_required + + def _check_joint_based_sync(self, current_leader_positions: dict) -> bool: + """Check joint-based synchronization with tolerance.""" + if self._follower_reference_positions is None: + return True + + all_within_tolerance = True + for joint_name, ref_pos in self._follower_reference_positions.items(): + if joint_name in current_leader_positions: + current_pos = current_leader_positions[joint_name] + error = abs(current_pos - ref_pos) + tolerance = self.sync_tolerances.get(joint_name, 10.0) + if error > tolerance: + all_within_tolerance = False + + return all_within_tolerance + + def get_config(self) -> dict[str, Any]: + return { + "use_gripper": self.use_gripper, + "terminate_on_success": self.terminate_on_success, + "sync_tolerances": self.sync_tolerances, + "position_threshold": self.position_threshold, + "orientation_threshold": self.orientation_threshold, + "gripper_threshold": self.gripper_threshold, + "stable_frames_required": self._stable_frames_required, + "has_kinematics_solver": self.kinematics_solver is not None, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + @dataclass @ProcessorStepRegistry.register("reward_classifier_processor") class RewardClassifierProcessorStep(ProcessorStep): diff --git a/src/lerobot/processor/joint_action_processor.py b/src/lerobot/processor/joint_action_processor.py new file mode 100644 index 0000000000..2f4339ba50 --- /dev/null +++ b/src/lerobot/processor/joint_action_processor.py @@ -0,0 +1,104 @@ +import numpy as np +import torch + +from dataclasses import dataclass, field + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature + +from .core import PolicyAction, RobotAction, TransitionKey, EnvTransition +from .pipeline import ( + ActionProcessorStep, + ProcessorStepRegistry, + RobotActionProcessorStep, + ProcessorStep, +) + + +@ProcessorStepRegistry.register("direct_joint_control") +@dataclass +class DirectJointControlStep(ProcessorStep): + """Process direct joint control commands from leader arm.""" + + motor_names: list[str] = field(default_factory=list) + use_gripper: bool = False + + def __call__(self, transition: EnvTransition) -> EnvTransition: + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + leader_joint_positions = complementary_data.get("leader_joint_positions") + + if leader_joint_positions is not None: + # Create robot action from leader joint positions + robot_action = {} + + # Handle arm joints + for i, motor_name in enumerate(self.motor_names): + if i < len(leader_joint_positions): + if isinstance(leader_joint_positions, torch.Tensor): + robot_action[f"{motor_name}.pos"] = leader_joint_positions[ + i + ].item() + else: + robot_action[f"{motor_name}.pos"] = float( + leader_joint_positions[i] + ) + + # Handle gripper if used + if self.use_gripper: + gripper_index = len(self.motor_names) + if ( + isinstance(leader_joint_positions, (list, tuple)) + and len(leader_joint_positions) > gripper_index + ): + if isinstance(leader_joint_positions, torch.Tensor): + robot_action["gripper.pos"] = leader_joint_positions[ + gripper_index + ].item() + else: + robot_action["gripper.pos"] = float( + leader_joint_positions[gripper_index] + ) + + # Store the robot action + complementary_data["robot_action"] = robot_action + transition[TransitionKey.COMPLEMENTARY_DATA] = 
complementary_data + + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # This step doesn't change the feature definitions + # It only processes complementary data, so we return features as-is + return features + + +@ProcessorStepRegistry.register("joint_bounds_and_safety") +@dataclass +class JointBoundsAndSafetyStep(ActionProcessorStep): + """Apply joint bounds and safety checks for direct joint control.""" + + joint_bounds: dict = field(default_factory=dict) + + def action(self, action: RobotAction) -> RobotAction: + if not isinstance(action, dict): + return action + + # Apply joint bounds if specified + bounded_action = action.copy() + for joint_name, action_value in action.items(): + if joint_name in self.joint_bounds: + bounds = self.joint_bounds[joint_name] + min_bound = bounds.get("min", -180.0) + max_bound = bounds.get("max", 180.0) + bounded_action[joint_name] = np.clip(action_value, min_bound, max_bound) + elif ".pos" in joint_name: + # Default safety bounds for joint positions + bounded_action[joint_name] = np.clip(action_value, -175.0, 175.0) + + return bounded_action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # This step doesn't change the feature definitions + return features diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index d22d8fb96e..e028ca1746 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -55,39 +55,66 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): def _process_single_image(self, img: np.ndarray) -> Tensor: """ Processes a single NumPy image array into a channel-first, normalized tensor. + Supports both 2D (depth/grayscale) and 3D (RGB) images. Args: - img: A NumPy array representing the image, expected to be in channel-last - (H, W, C) format with a `uint8` dtype. + img: A NumPy array representing the image. Can be: + - 2D array (H, W) for depth/grayscale images + - 3D array (H, W, C) for RGB images with `uint8` dtype Returns: A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with - pixel values normalized to the [0, 1] range. + pixel values normalized appropriately: + - [0, 1] range for uint8 images + - Original scale for depth maps (preserving metric values) Raises: - ValueError: If the input image does not appear to be in channel-last - format or is not of `uint8` dtype. + ValueError: If the input image has invalid dimensions or format. 
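+
+        Example (a rough sketch of the expected shapes; the sizes are
+        illustrative, not taken from a real camera):
+
+            rgb = np.zeros((480, 640, 3), dtype=np.uint8)
+            self._process_single_image(rgb).shape    # (1, 3, 480, 640), scaled to [0, 1]
+
+            depth = np.zeros((480, 640), dtype=np.float32)
+            self._process_single_image(depth).shape  # (1, 1, 480, 640), metric values kept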
""" # Convert to tensor img_tensor = torch.from_numpy(img) - # Add batch dimension if needed - if img_tensor.ndim == 3: + # Handle 2D images (depth maps, grayscale) + if img_tensor.ndim == 2: + # Add channel and batch dimensions: (H, W) -> (1, 1, H, W) + img_tensor = img_tensor.unsqueeze(0).unsqueeze(0) # (B=1, C=1, H, W) + + # For depth maps, preserve original values (don't normalize to [0,1]) + # Depth maps typically have meaningful metric values + if img.dtype in [np.float32, np.float64, np.uint16]: + # Keep depth values as-is for metric preservation + img_tensor = img_tensor.type(torch.float32) + else: + # For other 2D images (grayscale), normalize to [0,1] + img_tensor = img_tensor.type(torch.float32) / 255.0 + + # Handle 3D images (RGB, channel-last) + elif img_tensor.ndim == 3: + # Add batch dimension: (H, W, C) -> (1, H, W, C) img_tensor = img_tensor.unsqueeze(0) - # Validate image format - _, h, w, c = img_tensor.shape - if not (c < h and c < w): - raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}") + # Validate image format + _, h, w, c = img_tensor.shape + if not (c < h and c < w): + raise ValueError( + f"Expected channel-last images, but got shape {img_tensor.shape}" + ) + + if img_tensor.dtype != torch.uint8: + raise ValueError( + f"Expected torch.uint8 for RGB images, but got {img_tensor.dtype}" + ) - if img_tensor.dtype != torch.uint8: - raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}") + # Convert to channel-first format: (B, H, W, C) -> (B, C, H, W) + img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous() - # Convert to channel-first format - img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous() + # Convert to float32 and normalize to [0, 1] + img_tensor = img_tensor.type(torch.float32) / 255.0 - # Convert to float32 and normalize to [0, 1] - img_tensor = img_tensor.type(torch.float32) / 255.0 + else: + raise ValueError( + f"Unsupported image dimensions: {img_tensor.ndim}. Expected 2D or 3D array." + ) return img_tensor diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py index 25887d414e..d46f5a5f81 100644 --- a/src/lerobot/processor/policy_robot_bridge.py +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from typing import Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction +from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction, ProcessorStep, EnvTransition, TransitionKey from lerobot.utils.constants import ACTION @@ -67,3 +67,51 @@ def transform_features(self, features): type=FeatureType.ACTION, shape=(1,) ) return features + + +@ProcessorStepRegistry.register("direct_joint_to_policy_action") +@dataclass +class DirectJointToPolicyActionProcessorStep(ProcessorStep): + """Convert direct joint control to policy action.""" + + motor_names: list[str] = field(default_factory=list) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + # Check if we have direct joint control action + robot_action = complementary_data.get("robot_action") + if robot_action is not None: + # Use direct joint control action + action_values = [] + + # Add arm joint positions + for motor_name in self.motor_names: + joint_key = f"{motor_name}.pos" + if joint_key in robot_action: + action_values.append(robot_action[joint_key]) + else: + action_values.append(0.0) # Default value + + # Add gripper if present + if "gripper.pos" in robot_action: + action_values.append(robot_action["gripper.pos"]) + elif "gripper" in robot_action: # Fallback to non-.pos format + action_values.append(robot_action["gripper"]) + + # Convert to tensor + action_tensor = torch.tensor(action_values, dtype=torch.float32) + transition[TransitionKey.ACTION] = action_tensor + + # Store control mode in info for debugging + info = transition.get(TransitionKey.INFO, {}) + info["control_mode"] = "direct_joint" + transition[TransitionKey.INFO] = info + + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # This step doesn't change the feature definitions + return features diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index ad36f1b364..c8f3b75b24 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -26,6 +26,7 @@ from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import hw_to_dataset_features from lerobot.envs.configs import HILSerlRobotEnvConfig from lerobot.model.kinematics import RobotKinematics from lerobot.processor import ( @@ -34,17 +35,23 @@ AddTeleopEventsAsInfoStep, DataProcessorPipeline, DeviceProcessorStep, + DirectJointControlStep, EnvTransition, GripperPenaltyProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, JointVelocityProcessorStep, + JointBoundsAndSafetyStep, + LeaderArmInterventionProcessorStep, MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep, + MapTensorTo7DDeltaActionDictStep, + Map7DDeltaActionToRobotActionStep, MotorCurrentProcessorStep, Numpy2TorchActionProcessorStep, RewardClassifierProcessorStep, RobotActionToPolicyActionProcessorStep, + DirectJointToPolicyActionProcessorStep, TimeLimitProcessorStep, Torch2NumpyActionProcessorStep, TransitionKey, @@ -56,6 +63,7 @@ RobotConfig, make_robot_from_config, 
so100_follower, + so101_follower, ) from lerobot.robots.robot import Robot from lerobot.robots.so100_follower.robot_kinematic_processor import ( @@ -75,7 +83,10 @@ from lerobot.teleoperators.utils import TeleopEvents from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD from lerobot.utils.robot_utils import busy_wait -from lerobot.utils.utils import log_say +from lerobot.utils.utils import log_say, get_shape_as_tuple +from lerobot.utils.control_utils import ( + sanity_check_dataset_robot_compatibility, +) logging.basicConfig(level=logging.INFO) @@ -100,6 +111,7 @@ class GymManipulatorConfig: dataset: DatasetConfig mode: str | None = None # Either "record", "replay", None device: str = "cpu" + resume: bool = False # Resume recording on an existing dataset. def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None: @@ -151,7 +163,7 @@ def __init__( self.episode_data = None self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors] - self._image_keys = self.robot.cameras.keys() + self._image_keys = self.robot.image_keys self.reset_pose = reset_pose self.reset_time_s = reset_time_s @@ -462,41 +474,26 @@ def make_processors( action_pipeline_steps = [ AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device), AddTeleopEventsAsInfoStep(teleop_device=teleop_device), - InterventionActionProcessorStep( - use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False, + LeaderArmInterventionProcessorStep( + use_gripper=( + cfg.processor.gripper.use_gripper + if cfg.processor.gripper is not None + else False + ), terminate_on_success=terminate_on_success, + kinematics_solver=kinematics_solver, + motor_names=motor_names, ), - ] - - # Replace InverseKinematicsProcessor with new kinematic processors - if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None: - # Add EE bounds and safety processor - inverse_kinematics_steps = [ - MapTensorToDeltaActionDictStep( - use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False - ), - MapDeltaActionToRobotActionStep(), - EEReferenceAndDelta( - kinematics=kinematics_solver, - end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes, - motor_names=motor_names, - use_latched_reference=False, - use_ik_solution=True, - ), - EEBoundsAndSafety( - end_effector_bounds=cfg.processor.inverse_kinematics.end_effector_bounds, - ), - GripperVelocityToJoint( - clip_max=cfg.processor.max_gripper_pos, - speed_factor=1.0, - discrete_gripper=True, + DirectJointControlStep( + motor_names=motor_names, + use_gripper=( + cfg.processor.gripper.use_gripper + if cfg.processor.gripper is not None + else False ), - InverseKinematicsRLStep( - kinematics=kinematics_solver, motor_names=motor_names, initial_guess_current_joints=False - ), - ] - action_pipeline_steps.extend(inverse_kinematics_steps) - action_pipeline_steps.append(RobotActionToPolicyActionProcessorStep(motor_names=motor_names)) + ), + DirectJointToPolicyActionProcessorStep(motor_names=motor_names), + ] return DataProcessorPipeline( steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition @@ -505,6 +502,113 @@ def make_processors( ) +def debug_observation_structure(transition: EnvTransition) -> dict: + """ + Analyze and print the complete structure of observation data. + Returns a summary dictionary for further analysis. 
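+
+    A sketch of how the returned summary might be used by a caller (the
+    follow-up logging shown here is illustrative, not part of this helper):
+
+        summary = debug_observation_structure(transition)
+        if not summary["image_keys"]:
+            logging.warning("No image observations found in transition")
+        if summary["has_observation_state"]:
+            logging.info("State keys: %s", summary["state_keys"])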
+ """ + observation = transition.get(TransitionKey.OBSERVATION, {}) + + print(f"\n=== OBSERVATION STRUCTURE DEBUG ===") + print(f"Total observation keys: {len(observation)}") + + # Print all keys and their properties + for key, value in observation.items(): + print(f"\n--- {key} ---") + + # Handle different data types + if isinstance(value, (torch.Tensor, np.ndarray)): + print(f" Type: {type(value)}") + print(f" Shape: {value.shape}") + print(f" Dtype: {value.dtype}") + if hasattr(value, "numel"): + print(f" Num elements: {value.numel()}") + + # Print sample values for small tensors + if hasattr(value, "numel") and value.numel() <= 10: + print( + f" Values: {value.tolist() if hasattr(value, 'tolist') else value}" + ) + elif ( + hasattr(value, "shape") + and len(value.shape) == 1 + and value.shape[0] <= 10 + ): + print( + f" Values: {value.tolist() if hasattr(value, 'tolist') else value}" + ) + + elif isinstance(value, dict): + print(f" Type: dict with keys: {list(value.keys())}") + # Print first few items if it's a nested dict + for sub_key, sub_value in list(value.items())[:3]: + print(f" {sub_key}: {type(sub_value)}") + + else: + print(f" Type: {type(value)}") + print(f" Value: {value}") + + # Categorize keys + image_keys = [k for k in observation.keys() if "image" in k.lower()] + state_keys = [k for k in observation.keys() if "state" in k.lower()] + joint_keys = [ + k + for k in observation.keys() + if any(term in k.lower() for term in ["joint", "qpos", "angle", "position"]) + ] + ee_keys = [ + k + for k in observation.keys() + if any(term in k.lower() for term in ["ee", "end_effector", "tool", "tcp"]) + ] + + print(f"\n=== CATEGORIZED KEYS ===") + print(f"Image keys: {image_keys}") + print(f"State keys: {state_keys}") + print(f"Joint keys: {joint_keys}") + print(f"EE keys: {ee_keys}") + + # Analyze observation.state in detail if it exists + if "observation.state" in observation: + state = observation["observation.state"] + print(f"\n=== DETAILED observation.state ANALYSIS ===") + if isinstance(state, torch.Tensor): + state_size = state.numel() + print(f"State vector size: {state_size}") + + # Try to interpret the state vector + if state_size == 18: # Based on your dataset features + print("State appears to be 18-dimensional") + print("Possible interpretation:") + print(" Elements 0-3: EE position (x,y,z) + gripper?") + print(" Elements 4-9: 6 joint positions?") + print(" Elements 10-17: Other state information (velocity, etc.)?") + + # Print actual values + state_list = state.tolist() + print(f"First 10 elements: {state_list[:10]}") + + elif state_size == 6: # Possibly just joint positions + print("State appears to be 6-dimensional (possibly joint positions)") + print(f"Values: {state.tolist()}") + + else: + print(f"Unknown state structure with {state_size} elements") + print(f"All values: {state.tolist()}") + + print("=== END OBSERVATION DEBUG ===\n") + + # Return summary for programmatic use + return { + "all_keys": list(observation.keys()), + "image_keys": image_keys, + "state_keys": state_keys, + "joint_keys": joint_keys, + "ee_keys": ee_keys, + "has_observation_state": "observation.state" in observation, + } + + def step_env_and_process_transition( env: gym.Env, transition: EnvTransition, @@ -531,6 +635,7 @@ def step_env_and_process_transition( transition[TransitionKey.OBSERVATION] = ( env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {} ) + processed_action_transition = action_processor(transition) processed_action = 
processed_action_transition[TransitionKey.ACTION] @@ -601,7 +706,30 @@ def control_loop( if cfg.mode == "record": action_features = teleop_device.action_features features = { - ACTION: action_features, + "action": { + "dtype": "float32", + "shape": (6,), + "names": [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ], + }, + "leader_joint_positions": { + "dtype": "float32", + "shape": (6,), + "names": [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ], + }, REWARD: {"dtype": "float32", "shape": (1,), "names": None}, DONE: {"dtype": "bool", "shape": (1,), "names": None}, } @@ -616,26 +744,33 @@ def control_loop( if key == OBS_STATE: features[key] = { "dtype": "float32", - "shape": value.squeeze(0).shape, + "shape": get_shape_as_tuple(value), "names": None, } if "image" in key: features[key] = { "dtype": "video", - "shape": value.squeeze(0).shape, + "shape": get_shape_as_tuple(value), "names": ["channels", "height", "width"], } - # Create dataset - dataset = LeRobotDataset.create( - cfg.dataset.repo_id, - cfg.env.fps, - root=cfg.dataset.root, - use_videos=True, - image_writer_threads=4, - image_writer_processes=0, - features=features, - ) + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + ) + sanity_check_dataset_robot_compatibility(dataset, None, cfg.env.fps, features) + else: + # Create dataset + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.env.fps, + root=cfg.dataset.root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) episode_idx = 0 episode_step = 0 @@ -670,9 +805,24 @@ def control_loop( action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get( "teleop_action", transition[TransitionKey.ACTION] ) + + # Get leader joint positions from complementary data + leader_joint_positions = transition[TransitionKey.COMPLEMENTARY_DATA].get( + "leader_joint_positions", None + ) + if leader_joint_positions is None: + leader_joint_positions_tensor = torch.zeros(6) # 6 joints including gripper + else: + leader_joint_positions_tensor = ( + leader_joint_positions.cpu() + if isinstance(leader_joint_positions, torch.Tensor) + else torch.tensor(leader_joint_positions) + ) + frame = { **observations, ACTION: action_to_record.cpu(), + "leader_joint_positions": leader_joint_positions_tensor, REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32), DONE: np.array([terminated or truncated], dtype=bool), } @@ -753,6 +903,7 @@ def main(cfg: GymManipulatorConfig) -> None: """Main entry point for gym manipulator script.""" env, teleop_device = make_robot_env(cfg.env) env_processor, action_processor = make_processors(env, teleop_device, cfg.env, cfg.device) + # Full processor pipeline for real robot environment print("Environment observation space:", env.observation_space) print("Environment action space:", env.action_space) diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index 5e88b915b1..5f8418ec5f 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -183,3 +183,13 @@ def send_action(self, action: dict[str, Any]) -> dict[str, Any]: def disconnect(self) -> None: """Disconnect from the robot and perform any necessary cleanup.""" pass + + @property + def image_keys(self) -> list[str]: + """Return the keys of available camera images.""" + cameras = getattr(getattr(self, "config", None), "cameras", None) + + if cameras is None: + raise 
RuntimeError("Camera configuration not properly initialized") + + return list(cameras.keys()) diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index 87e832db6e..00b5679c11 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -548,10 +548,10 @@ def __call__(self, transition: EnvTransition) -> EnvTransition: wz = action.pop("ee.wz") gripper_pos = action.pop("ee.gripper_pos") - if None in (x, y, z, wx, wy, wz, gripper_pos): - raise ValueError( - "Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action" - ) + # 🚨 添加输入debug + print(f"IK INPUT DEBUG:") + print(f" Target EE - pos: [{x:.3f}, {y:.3f}, {z:.3f}]") + print(f" Target EE - rot: [{wx:.3f}, {wy:.3f}, {wz:.3f}]") observation = new_transition.get(TransitionKey.OBSERVATION).copy() if observation is None: @@ -561,22 +561,41 @@ def __call__(self, transition: EnvTransition) -> EnvTransition: [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")], dtype=float, ) - if q_raw is None: - raise ValueError("Joints observation is require for computing robot kinematics") + + # 🚨 添加当前关节状态debug + print(f" Current follower joints: {[f'{q:.1f}' for q in q_raw]}") if self.initial_guess_current_joints: # Use current joints as initial guess self.q_curr = q_raw + print(f" Initial guess: CURRENT joints") else: # Use previous ik solution as initial guess if self.q_curr is None: self.q_curr = q_raw + print(f" Initial guess: PREVIOUS IK solution") # Build desired 4x4 transform from pos + rotvec (twist) t_des = np.eye(4, dtype=float) t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix() t_des[:3, 3] = [x, y, z] + # 🚨 添加初始猜测debug + print(f" Initial guess joints: {[f'{q:.1f}' for q in self.q_curr]}") + # Compute inverse kinematics q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des) + + # 🚨 添加IK结果debug + print(f" IK solution: {[f'{q:.1f}' for q in q_target]}") + print(f" Joint changes: {[f'{(q_target[i]-q_raw[i]):.1f}' for i in range(len(q_target))]}") + + # 🚨 验证FK一致性 + fk_verification = self.kinematics.forward_kinematics(q_target) + fk_pos = fk_verification[:3, 3] + pos_error = np.linalg.norm(fk_pos - np.array([x, y, z])) + print(f" FK verification error: {pos_error:.4f}m") + + print("=" * 50) + self.q_curr = q_target # TODO: This is sentitive to order of motor_names = q_target mapping diff --git a/src/lerobot/robots/so101_follower/config_so101_follower.py b/src/lerobot/robots/so101_follower/config_so101_follower.py index 03c3530c2f..1d824b9918 100644 --- a/src/lerobot/robots/so101_follower/config_so101_follower.py +++ b/src/lerobot/robots/so101_follower/config_so101_follower.py @@ -39,3 +39,9 @@ class SO101FollowerConfig(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False + + urdf_path: str = field( + default="./src/lerobot/robots/so101_follower/sim/so101_new_calib.urdf", metadata=dict(help="Path to URDF file") + ) + + target_frame_name: str = field(default="gripper_frame_link", metadata=dict(help="Target frame name for kinematics")) diff --git a/src/lerobot/robots/so101_follower/so101_follower.py b/src/lerobot/robots/so101_follower/so101_follower.py index acfd4bd114..314c2eb2c7 100644 --- a/src/lerobot/robots/so101_follower/so101_follower.py +++ 
b/src/lerobot/robots/so101_follower/so101_follower.py @@ -64,11 +64,28 @@ def __init__(self, config: SO101FollowerConfig): def _motors_ft(self) -> dict[str, type]: return {f"{motor}.pos": float for motor in self.bus.motors} + @property + def image_keys(self) -> list[str]: + return list(self._cameras_ft.keys()) + @property def _cameras_ft(self) -> dict[str, tuple]: - return { - cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras - } + """Generate camera features with multi-modal support.""" + features = {} + + for cam_name, camera_config in self.config.cameras.items(): + # Depth feature if enabled + if hasattr(camera_config, "use_depth") and camera_config.use_depth: + features[f"{cam_name}_depth"] = ( + camera_config.height, + camera_config.width, + ) + features[f"{cam_name}_rgb"] = (camera_config.height, camera_config.width, 3) + else: + # RGB only + features[cam_name] = (camera_config.height, camera_config.width, 3) + + return features @cached_property def observation_features(self) -> dict[str, type | tuple]: @@ -184,7 +201,13 @@ def get_observation(self) -> dict[str, Any]: # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + images = cam.async_read() + if len(images) == 1: + modality, image = next(iter(images.items())) + obs_dict[cam_key] = image + elif len(images) > 1: + for modality, image in images.items(): + obs_dict[f"{cam_key}_{modality}"] = image dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/scripts/lerobot_find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py index e17dca8055..732201fe52 100644 --- a/src/lerobot/scripts/lerobot_find_cameras.py +++ b/src/lerobot/scripts/lerobot_find_cameras.py @@ -42,6 +42,10 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.realsense.camera_realsense import RealSenseCamera from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig +from lerobot.cameras.utils import get_image_modality_key +from lerobot.cameras.zed.camera_zed import ZedCamera + +from lerobot.cameras.zed.camera_zed import ZedCameraConfig logger = logging.getLogger(__name__) @@ -87,13 +91,34 @@ def find_all_realsense_cameras() -> list[dict[str, Any]]: return all_realsense_cameras_info +def find_all_zed_cameras() -> list[dict[str, Any]]: + """ + Finds all available ZED cameras plugged into the system. + + Returns: + A list of all available ZED cameras with their metadata. + """ + all_zed_cameras_info: list[dict[str, Any]] = [] + logger.info("Searching for ZED cameras...") + try: + zed_cameras = ZedCamera.find_cameras() + for cam_info in zed_cameras: + all_zed_cameras_info.append(cam_info) + logger.info(f"Found {len(zed_cameras)} ZED cameras.") + except ImportError: + logger.warning("Skipping ZED camera search: pyzed library not found or not importable.") + except Exception as e: + logger.error(f"Error finding ZED cameras: {e}") + + return all_zed_cameras_info + def find_and_print_cameras(camera_type_filter: str | None = None) -> list[dict[str, Any]]: """ Finds available cameras based on an optional filter and prints their information. Args: - camera_type_filter: Optional string to filter cameras ("realsense" or "opencv"). + camera_type_filter: Optional string to filter cameras ("realsense", "zed" or "opencv"). If None, lists all cameras. 
Returns: @@ -108,6 +133,8 @@ def find_and_print_cameras(camera_type_filter: str | None = None) -> list[dict[s all_cameras_info.extend(find_all_opencv_cameras()) if camera_type_filter is None or camera_type_filter == "realsense": all_cameras_info.extend(find_all_realsense_cameras()) + if camera_type_filter is None or camera_type_filter == "zed": + all_cameras_info.extend(find_all_zed_cameras()) if not all_cameras_info: if camera_type_filter: @@ -130,27 +157,111 @@ def find_and_print_cameras(camera_type_filter: str | None = None) -> list[dict[s def save_image( - img_array: np.ndarray, + image_data: np.ndarray | dict[str, np.ndarray], camera_identifier: str | int, images_dir: Path, camera_type: str, + modality: str | None = None, ): """ - Saves a single image to disk using Pillow. Handles color conversion if necessary. + Saves image data to disk using Pillow. Supports multiple modalities. + + Args: + image_data: Single image array or dictionary of modality-keyed images + camera_identifier: Unique identifier for the camera + images_dir: Directory where images will be saved + camera_type: Type of camera (e.g., 'zed', 'opencv') + modality: Explicit modality type. If None and image_data is a dict, + saves all modalities with automatic key detection. + + Note: + Supported modalities and their handling: + - 'gray': Grayscale images, saved as 8-bit PNG + - 'rgb': RGB images, saved as standard PNG + - 'rgba': RGBA images, saved as PNG with alpha + - 'depth': Depth maps, saved as 16-bit PNG + - 'ir': Infrared images, saved as 8-bit PNG """ try: - img = Image.fromarray(img_array, mode="RGB") - + # Handle dictionary input (multiple modalities) + if isinstance(image_data, dict): + futures = [] + for mod, img_array in image_data.items(): + # Recursively call save_image for each modality + future = save_image( + img_array, camera_identifier, images_dir, camera_type, mod + ) + if future: + futures.append(future) + return futures + + # Handle single image array + img_array = image_data safe_identifier = str(camera_identifier).replace("/", "_").replace("\\", "_") - filename_prefix = f"{camera_type.lower()}_{safe_identifier}" - filename = f"{filename_prefix}.png" + # Auto-detect modality if not explicitly provided + if modality is None: + modality = get_image_modality_key(img_array) + + # Process based on modality + if modality == "depth": + # Depth image processing + if img_array.dtype != np.uint16: + img_array = img_array.astype(np.uint16) + img = Image.fromarray(img_array, mode="I;16") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}_depth" + + elif modality == "gray": + # Grayscale image processing + if img_array.dtype != np.uint8: + # Normalize to 0-255 range for 8-bit grayscale + if img_array.dtype in [np.float32, np.float64]: + img_array = (img_array * 255).astype(np.uint8) + else: + img_array = img_array.astype(np.uint8) + img = Image.fromarray(img_array, mode="L") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}_gray" + + elif modality == "rgb": + # RGB image processing + if img_array.dtype != np.uint8: + img_array = img_array.astype(np.uint8) + img = Image.fromarray(img_array, mode="RGB") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}_rgb" + + elif modality == "rgba": + # RGBA image processing + if img_array.dtype != np.uint8: + img_array = img_array.astype(np.uint8) + img = Image.fromarray(img_array, mode="RGBA") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}_rgba" + + elif modality == "ir": + # Infrared image processing (similar to grayscale) 
+ if img_array.dtype != np.uint8: + img_array = img_array.astype(np.uint8) + img = Image.fromarray(img_array, mode="L") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}_ir" + + else: + # Fallback for unknown modalities + logger.warning(f"Unknown modality '{modality}', saving as RGB") + if img_array.dtype != np.uint8: + img_array = img_array.astype(np.uint8) + img = Image.fromarray(img_array, mode="RGB") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}_{modality}" + + filename = f"{filename_prefix}.png" path = images_dir / filename path.parent.mkdir(parents=True, exist_ok=True) img.save(str(path)) - logger.info(f"Saved image: {path}") + logger.info(f"Saved {modality} image: {path}") + except Exception as e: - logger.error(f"Failed to save image for camera {camera_identifier} (type {camera_type}): {e}") + logger.error( + f"Failed to save image for camera {camera_identifier} " + f"(type {camera_type}, modality={modality}): {e}" + ) def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None: @@ -174,6 +285,12 @@ def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None: color_mode=ColorMode.RGB, ) instance = RealSenseCamera(rs_config) + elif cam_type == "ZED": + zed_config = ZedCameraConfig( + serial_number_or_name=cam_id, + color_mode=ColorMode.RGB, + ) + instance = ZedCamera(zed_config) else: logger.warning(f"Unknown camera type: {cam_type} for ID {cam_id}. Skipping.") return None @@ -197,25 +314,25 @@ def process_camera_image( meta = cam_dict["meta"] cam_type_str = str(meta.get("type", "unknown")) cam_id_str = str(meta.get("id", "unknown")) - + logger.info(f"{cam=}\n{meta=}") try: image_data = cam.read() - return save_image( image_data, cam_id_str, output_dir, cam_type_str, ) + except TimeoutError: logger.warning( f"Timeout reading from {cam_type_str} camera {cam_id_str} at time {current_time:.2f}s." ) except Exception as e: logger.error(f"Error reading from {cam_type_str} camera {cam_id_str}: {e}") + raise e return None - def cleanup_cameras(cameras_to_use: list[dict[str, Any]]): """Disconnect all cameras.""" logger.info(f"Disconnecting {len(cameras_to_use)} cameras...") @@ -245,11 +362,10 @@ def save_images_from_all_cameras( output_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Saving images to {output_dir}") all_camera_metadata = find_and_print_cameras(camera_type_filter=camera_type) - + print(f"{all_camera_metadata=}") if not all_camera_metadata: logger.warning("No cameras detected matching the criteria. Cannot save images.") return - cameras_to_use = [] for cam_meta in all_camera_metadata: camera_instance = create_camera_instance(cam_meta) @@ -296,8 +412,8 @@ def main(): type=str, nargs="?", default=None, - choices=["realsense", "opencv"], - help="Specify camera type to capture from (e.g., 'realsense', 'opencv'). Captures from all if omitted.", + choices=["realsense", "opencv", "zed"], + help="Specify camera type to capture from (e.g., 'realsense', 'opencv', 'zed'). 
Captures from all if omitted.", ) parser.add_argument( "--output-dir", diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index 07d57a7608..e51a8057ce 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -42,6 +42,7 @@ koch_follower, make_robot_from_config, so100_follower, + so101_follower, ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, @@ -49,6 +50,7 @@ koch_leader, make_teleoperator_from_config, so100_leader, + so101_leader, ) from lerobot.utils.robot_utils import busy_wait @@ -77,6 +79,7 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig): # Note to be compatible with the rest of the codebase, # we are using the new calibration method for so101 and so100 robot_type = "so_new_calibration" + print(f"{cfg.robot.urdf_path=}") kinematics = RobotKinematics(cfg.robot.urdf_path, cfg.robot.target_frame_name) # Initialize min/max values diff --git a/src/lerobot/teleoperators/so101_leader/so101_leader.py b/src/lerobot/teleoperators/so101_leader/so101_leader.py index be804bf702..2689c83141 100644 --- a/src/lerobot/teleoperators/so101_leader/so101_leader.py +++ b/src/lerobot/teleoperators/so101_leader/so101_leader.py @@ -15,7 +15,11 @@ # limitations under the License. import logging +import os +import sys import time +from queue import Queue +from typing import Any from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( @@ -25,11 +29,28 @@ from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator +from ..utils import TeleopEvents from .config_so101_leader import SO101LeaderConfig logger = logging.getLogger(__name__) +PYNPUT_AVAILABLE = True +try: + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + logging.info("No DISPLAY set. Skipping pynput import.") + raise ImportError("pynput blocked intentionally due to no display.") + + from pynput import keyboard +except ImportError: + keyboard = None + PYNPUT_AVAILABLE = False +except Exception as e: + keyboard = None + PYNPUT_AVAILABLE = False + logging.info(f"Could not import pynput: {e}") + + class SO101Leader(Teleoperator): """ SO-101 Leader Arm designed by TheRobotStudio and Hugging Face. 
@@ -54,6 +75,30 @@ def __init__(self, config: SO101LeaderConfig):
             },
             calibration=self.calibration,
         )
+        # Initialize keyboard event handling
+        self.misc_keys_queue = Queue()
+        self._setup_keyboard_listener()
+
+    def _setup_keyboard_listener(self):
+        """Set up the keyboard listener for teleoperation events."""
+        if not PYNPUT_AVAILABLE:
+            logger.info("pynput not available - keyboard events will not work for SO101Leader")
+            return
+
+        def on_press(key):
+            # Only process event keys, not movement keys
+            try:
+                logger.debug(f"Key pressed: {key}")
+                if hasattr(key, "char") and key.char is not None:
+                    if key.char in ["s", "r", "q"]:
+                        self.misc_keys_queue.put(key)
+                elif key in [keyboard.Key.space, keyboard.Key.esc]:
+                    self.misc_keys_queue.put(key)
+            except AttributeError:
+                pass
+
+        self.keyboard_listener = keyboard.Listener(on_press=on_press)
+        self.keyboard_listener.start()
 
     @property
     def action_features(self) -> dict[str, type]:
@@ -149,8 +194,91 @@ def send_feedback(self, feedback: dict[str, float]) -> None:
         raise NotImplementedError
 
     def disconnect(self) -> None:
+        """Disconnect the SO101 leader arm and clean up the keyboard listener."""
+        # Stop the keyboard listener if it is running
+        if hasattr(self, "keyboard_listener") and self.keyboard_listener.is_alive():
+            self.keyboard_listener.stop()
+
         if not self.is_connected:
-            DeviceNotConnectedError(f"{self} is not connected.")
+            raise DeviceNotConnectedError(f"{self} is not connected.")
 
         self.bus.disconnect()
         logger.info(f"{self} disconnected.")
+
+    def get_teleop_events(self) -> dict[str, Any]:
+        """
+        Get extra control events from the keyboard for the SO101 leader arm.
+
+        Keyboard mappings for events (based on the official tutorial):
+        - 's' key = success (terminate episode successfully)
+        - 'r' key = rerecord episode (terminate and rerecord)
+        - 'q' key = quit episode (terminate without success)
+        - 'space' key = intervention (take over/give back control)
+        - 'esc' key = failure/terminate episode
+
+        Note: the SO101 leader arm uses hardware for motion control; the keyboard is only used for events.
+
+        Returns:
+            Dictionary containing:
+            - is_intervention: bool - Whether a human is currently intervening
+            - terminate_episode: bool - Whether to terminate the current episode
+            - success: bool - Whether the episode was successful
+            - rerecord_episode: bool - Whether to rerecord the episode
+        """
+        if not self.is_connected:
+            return {
+                TeleopEvents.IS_INTERVENTION: False,
+                TeleopEvents.TERMINATE_EPISODE: False,
+                TeleopEvents.SUCCESS: False,
+                TeleopEvents.RERECORD_EPISODE: False,
+            }
+
+        # Initialize event states
+        is_intervention = self._is_intervention_active  # Use persisted state
+        terminate_episode = False
+        success = False
+        rerecord_episode = False
+
+        # Process any pending misc keys from the queue
+        while not self.misc_keys_queue.empty():
+            try:
+                key = self.misc_keys_queue.get_nowait()
+                logger.debug(f"SO101Leader: processing key {key}")
+
+                # Handle character keys
+                if hasattr(key, "char") and key.char:
+                    if key.char == "s":
+                        # Success - terminate episode successfully
+                        success = True
+                        terminate_episode = True
+                    elif key.char == "r":
+                        # Rerecord - terminate and rerecord episode
+                        rerecord_episode = True
+                        terminate_episode = True
+                    elif key.char == "q":
+                        # Quit - terminate without success
+                        terminate_episode = True
+                        success = False
+
+                # Handle special keys
+                elif key == keyboard.Key.esc:
+                    # ESC - terminate episode (failure)
+                    terminate_episode = True
+                    success = False
+
+                elif key == keyboard.Key.space:
+                    # Space - toggle intervention (take over/give back control)
+                    # The toggle persists across calls via self._is_intervention_active
+                    self._is_intervention_active = not self._is_intervention_active
+                    is_intervention = self._is_intervention_active
+
+            except Exception as e:
+                logger.debug(f"Error processing keyboard event: {e}")
+                continue
+
+        return {
+            TeleopEvents.IS_INTERVENTION: is_intervention,
+            TeleopEvents.TERMINATE_EPISODE: terminate_episode,
+            TeleopEvents.SUCCESS: success,
+            TeleopEvents.RERECORD_EPISODE: rerecord_episode,
+        }
diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py
index 95020a962f..1a15f3db58 100644
--- a/src/lerobot/teleoperators/teleoperator.py
+++ b/src/lerobot/teleoperators/teleoperator.py
@@ -53,6 +53,7 @@ def __init__(self, config: TeleoperatorConfig):
         self.calibration: dict[str, MotorCalibration] = {}
         if self.calibration_fpath.is_file():
             self._load_calibration()
+        self._is_intervention_active = False  # Track intervention state
 
     def __str__(self) -> str:
         return f"{self.id} {self.__class__.__name__}"
diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py
index 17371921cd..da2adc1f56 100644
--- a/src/lerobot/utils/control_utils.py
+++ b/src/lerobot/utils/control_utils.py
@@ -215,7 +215,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
 
 
 def sanity_check_dataset_robot_compatibility(
-    dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
+    dataset: LeRobotDataset, robot: Robot | None, fps: int, features: dict
 ) -> None:
     """
     Checks if a dataset's metadata is compatible with the current robot and recording setup.
@@ -233,10 +233,11 @@ def sanity_check_dataset_robot_compatibility(
         ValueError: If any of the checked metadata fields do not match.
""" fields = [ - ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), ("features", dataset.features, {**features, **DEFAULT_FEATURES}), ] + if robot is not None: + fields.append(("robot_type", dataset.meta.robot_type, robot.robot_type)) mismatches = [] for field, dataset_value, present_value in fields: diff --git a/src/lerobot/utils/rotation.py b/src/lerobot/utils/rotation.py index 41b6529478..389dadb3fb 100644 --- a/src/lerobot/utils/rotation.py +++ b/src/lerobot/utils/rotation.py @@ -268,3 +268,185 @@ def __mul__(self, other: "Rotation") -> "Rotation": ) return Rotation(composed_quat) + + def as_euler(self, seq: str, degrees: bool = False) -> np.ndarray: + """ + Convert the rotation to Euler angles. + + Args: + seq: Axis sequence, e.g., "xyz", "zyx", "xzy", "yxz", "yzx", "zxy". + Only proper Tait-Bryan sequences with all distinct axes are supported. + degrees: If True, return angles in degrees; otherwise radians. + + Returns: + ndarray of shape (3,) with angles [a1, a2, a3] for the given sequence. + """ + seq = seq.lower() + valid = {"xyz", "xzy", "yxz", "yzx", "zxy", "zyx"} + if seq not in valid: + raise ValueError( + f"Unsupported euler sequence '{seq}'. Supported: {sorted(valid)}" + ) + + R = self.as_matrix() + r00, r01, r02 = R[0, 0], R[0, 1], R[0, 2] + r10, r11, r12 = R[1, 0], R[1, 1], R[1, 2] + r20, r21, r22 = R[2, 0], R[2, 1], R[2, 2] + + # For Tait-Bryan sequences (all axes different), formulas: + # xyz: + # sy = -r20 + # y = asin(sy) + # x = atan2(r21, r22) + # z = atan2(r10, r00) + # xzy: + # sz = r10 + # z = asin(sz) + # x = atan2(-r12, r11) + # y = atan2(-r20, r00) + # yxz: + # sx = r21 + # x = asin(sx) + # y = atan2(-r20, r22) + # z = atan2(-r01, r00) + # yzx: + # sz = -r01 + # z = asin(sz) + # y = atan2(r02, r00) + # x = atan2(r21, r11) + # zxy: + # sx = -r12 + # x = asin(sx) + # z = atan2(r10, r11) + # y = atan2(r02, r22) + # zyx: + # sy = -r02 + # y = asin(sy) + # z = atan2(r01, r00) + # x = atan2(r12, r22) + # + # Handle gimbal lock when |sin(mid)| ~ 1. 
+
+        eps = 1e-8
+
+        if seq == "xyz":
+            sy = -r20
+            y = np.arcsin(np.clip(sy, -1.0, 1.0))
+            if abs(sy) < 1 - eps:
+                x = np.arctan2(r21, r22)
+                z = np.arctan2(r10, r00)
+            else:
+                # Gimbal lock: set z to 0 and recover x from r11/r12
+                x = np.arctan2(-r12, r11)
+                z = 0.0
+
+        elif seq == "xzy":
+            sz = r10
+            z = np.arcsin(np.clip(sz, -1.0, 1.0))
+            if abs(sz) < 1 - eps:
+                x = np.arctan2(-r12, r11)
+                y = np.arctan2(-r20, r00)
+            else:
+                x = np.arctan2(r21, r22)
+                y = 0.0
+
+        elif seq == "yxz":
+            sx = r21
+            x = np.arcsin(np.clip(sx, -1.0, 1.0))
+            if abs(sx) < 1 - eps:
+                y = np.arctan2(-r20, r22)
+                z = np.arctan2(-r01, r00)
+            else:
+                y = np.arctan2(r02, r00)
+                z = 0.0
+
+        elif seq == "yzx":
+            sz = -r01
+            z = np.arcsin(np.clip(sz, -1.0, 1.0))
+            if abs(sz) < 1 - eps:
+                y = np.arctan2(r02, r00)
+                x = np.arctan2(r21, r11)
+            else:
+                y = np.arctan2(-r20, r22)
+                x = 0.0
+
+        elif seq == "zxy":
+            sx = -r12
+            x = np.arcsin(np.clip(sx, -1.0, 1.0))
+            if abs(sx) < 1 - eps:
+                z = np.arctan2(r10, r11)
+                y = np.arctan2(r02, r22)
+            else:
+                z = np.arctan2(-r01, r00)
+                y = 0.0
+
+        elif seq == "zyx":
+            sy = -r02
+            y = np.arcsin(np.clip(sy, -1.0, 1.0))
+            if abs(sy) < 1 - eps:
+                z = np.arctan2(r01, r00)
+                x = np.arctan2(r12, r22)
+            else:
+                z = np.arctan2(-r10, r11)
+                x = 0.0
+
+        # Return the angles in the order of the requested sequence.
+        angles = {
+            "xyz": np.array([x, y, z]),
+            "xzy": np.array([x, z, y]),
+            "yxz": np.array([y, x, z]),
+            "yzx": np.array([y, z, x]),
+            "zxy": np.array([z, x, y]),
+            "zyx": np.array([z, y, x]),
+        }[seq]
+
+        if degrees:
+            angles = np.degrees(angles)
+        return angles
+
+    @classmethod
+    def from_euler(cls, seq: str, angles, degrees: bool = False) -> "Rotation":
+        """
+        Create a rotation from Euler angles.
+
+        Args:
+            seq: Axis sequence, e.g., "xyz", "zyx", "xzy", "yxz", "yzx", "zxy".
+            angles: Iterable of 3 angles [a1, a2, a3].
+            degrees: If True, input angles are in degrees.
+
+        Returns:
+            Rotation instance.
+        """
+        seq = seq.lower()
+        valid = {"xyz", "xzy", "yxz", "yzx", "zxy", "zyx"}
+        if seq not in valid:
+            raise ValueError(
+                f"Unsupported euler sequence '{seq}'. Supported: {sorted(valid)}"
+            )
+
+        angles = np.asarray(angles, dtype=float).reshape(3)
+        if degrees:
+            angles = np.radians(angles)
+
+        a1, a2, a3 = angles
+
+        def Rx(a):
+            ca, sa = np.cos(a), np.sin(a)
+            return np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]], dtype=float)
+
+        def Ry(a):
+            ca, sa = np.cos(a), np.sin(a)
+            return np.array([[ca, 0, sa], [0, 1, 0], [-sa, 0, ca]], dtype=float)
+
+        def Rz(a):
+            ca, sa = np.cos(a), np.sin(a)
+            return np.array([[ca, -sa, 0], [sa, ca, 0], [0, 0, 1]], dtype=float)
+
+        axis_map = {"x": Rx, "y": Ry, "z": Rz}
+        R1 = axis_map[seq[0]](a1)
+        R2 = axis_map[seq[1]](a2)
+        R3 = axis_map[seq[2]](a3)
+
+        # Active rotations applied in order: first a1 about seq[0], then a2, then a3
+        R = R3 @ R2 @ R1
+        return cls.from_matrix(R)
+
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index 8777d5a9db..06d024a47f 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -48,6 +48,14 @@ def auto_select_torch_device() -> torch.device:
     return torch.device("cpu")
 
 
+def get_shape_as_tuple(tensor):
+    """Extract the tensor shape and ensure it is returned as a plain tuple."""
+    shape = tensor.squeeze(0).shape
+    if isinstance(shape, torch.Size):
+        return tuple(shape)
+    return shape
+
+
 # TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
 def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
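
Note: a quick way to sanity-check the new `as_euler`/`from_euler` pair is a round-trip test. The sketch below is not part of the patch; it assumes the class is importable as `lerobot.utils.rotation.Rotation` (matching the file path above) and keeps all angles well inside (-pi/2, pi/2) so the extrinsic decomposition is unique.

```python
import numpy as np

from lerobot.utils.rotation import Rotation  # import path assumed from src/lerobot/utils/rotation.py

rng = np.random.default_rng(0)
for seq in ("xyz", "xzy", "yxz", "yzx", "zxy", "zyx"):
    # Stay clear of gimbal lock so the angles themselves, not just the matrix, round-trip.
    angles = rng.uniform(-1.4, 1.4, size=3)
    rot = Rotation.from_euler(seq, angles)
    recovered = rot.as_euler(seq)
    np.testing.assert_allclose(recovered, angles, atol=1e-6)
    # The rotation matrix must also survive the round trip.
    np.testing.assert_allclose(
        Rotation.from_euler(seq, recovered).as_matrix(), rot.as_matrix(), atol=1e-6
    )
```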
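Note: the new `get_teleop_events()` hook is meant to be polled alongside the usual teleoperation loop. A minimal consumer sketch follows; it assumes the `so101_leader` package re-exports `SO101Leader`/`SO101LeaderConfig` and that the config accepts a serial `port` (the device path below is a placeholder), so adjust imports and arguments to your tree.

```python
import time

from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig  # re-export assumed
from lerobot.teleoperators.utils import TeleopEvents

leader = SO101Leader(SO101LeaderConfig(port="/dev/ttyACM0"))  # placeholder port
leader.connect()
try:
    while True:
        events = leader.get_teleop_events()
        if events[TeleopEvents.IS_INTERVENTION]:
            # Space toggled intervention on: forward leader actions to the follower here.
            pass
        if events[TeleopEvents.TERMINATE_EPISODE]:
            outcome = "success" if events[TeleopEvents.SUCCESS] else "failure"
            print(f"Episode ended ({outcome}), rerecord={events[TeleopEvents.RERECORD_EPISODE]}")
            break
        time.sleep(0.01)  # crude rate limiting for the sketch
finally:
    leader.disconnect()
```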
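Note: since `read()`/`async_read()` now return a modality-keyed dictionary, downstream code should iterate over the mapping rather than assume a single array. A sketch using the new ZED backend; the `camera_zed` module name and the serial number are assumptions (any `Camera` subclass works the same way).

```python
from lerobot.cameras import ColorMode
from lerobot.cameras.zed.configuration_zed import ZedCameraConfig
from lerobot.cameras.zed.camera_zed import ZedCamera  # module name assumed by analogy with camera_opencv

config = ZedCameraConfig(serial_number_or_name="12345", color_mode=ColorMode.RGB)  # placeholder serial
camera = ZedCamera(config)
camera.connect()
try:
    frames = camera.read()
    # Typical keys are "image" and "depth"; camera-specific extras (e.g. "ir") may appear as well.
    for modality, array in frames.items():
        print(f"{modality}: shape={array.shape}, dtype={array.dtype}")
finally:
    camera.disconnect()
```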