diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index c8bc5049ec..8ac6f933df 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -17,6 +17,7 @@
 import logging
 import shutil
 import tempfile
+import torch
 from collections.abc import Callable
 from pathlib import Path
@@ -1003,7 +1004,45 @@ def __getitem__(self, idx) -> dict:
         if self.image_transforms is not None:
             image_keys = self.meta.camera_keys
             for cam in image_keys:
-                item[cam] = self.image_transforms(item[cam])
+                cam_val = item.get(cam, None)
+                if cam_val is None:
+                    continue
+
+                # Convert non-tensor camera values safely
+                if not isinstance(cam_val, torch.Tensor):
+                    try:
+                        cam_val = torch.as_tensor(cam_val)
+                    except (TypeError, ValueError):
+                        item[cam] = self.image_transforms(cam_val)
+                        continue
+
+                # Convert [N, H, W, C] → [N, C, H, W]
+                if cam_val.dim() == 4 and cam_val.shape[-1] == 3:
+                    cam_val = cam_val.permute(0, 3, 1, 2)
+
+                # Convert [H, W, C] → [C, H, W]
+                if cam_val.dim() == 3 and cam_val.shape[-1] == 3:
+                    cam_val = cam_val.permute(2, 0, 1)
+
+                # Multi-frame case [N, C, H, W] → [N*C, H, W]
+                if cam_val.dim() == 4:
+                    frames = []
+                    for f_idx in range(cam_val.shape[0]):
+                        frame = cam_val[f_idx]
+                        frame_t = self.image_transforms(frame)
+                        if not isinstance(frame_t, torch.Tensor):
+                            frame_t = torch.as_tensor(frame_t)
+                        if frame_t.dim() == 3 and frame_t.shape[-1] == 3:
+                            frame_t = frame_t.permute(2, 0, 1)
+                        frames.append(frame_t)
+
+                    stacked = torch.stack(frames, dim=0)
+                    N, C, H, W = stacked.shape
+                    item[cam] = stacked.view(N * C, H, W)
+
+                else:
+                    # Apply transforms for single-frame or unexpected shapes
+                    item[cam] = self.image_transforms(cam_val)

         # Add task as a string
         task_idx = item["task_index"].item()
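
For context, here is a minimal standalone sketch of the shape handling this hunk introduces. It uses a stand-in identity transform and synthetic tensors; the helper name `normalize_and_transform` and the sample shapes are illustrative only and are not part of the patch:

```python
# Illustrative sketch only: mirrors the shape handling added to __getitem__,
# with a stand-in transform and synthetic tensors (not part of the PR).
import torch


def normalize_and_transform(cam_val, image_transforms):
    """Apply image_transforms to a camera value of shape
    [H, W, C], [C, H, W], [N, H, W, C], or [N, C, H, W]."""
    if not isinstance(cam_val, torch.Tensor):
        cam_val = torch.as_tensor(cam_val)

    # Channels-last -> channels-first
    if cam_val.dim() == 4 and cam_val.shape[-1] == 3:
        cam_val = cam_val.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
    if cam_val.dim() == 3 and cam_val.shape[-1] == 3:
        cam_val = cam_val.permute(2, 0, 1)     # [H, W, C] -> [C, H, W]

    if cam_val.dim() == 4:
        # Multi-frame: transform each frame, then stack and flatten to [N*C, H, W]
        frames = [image_transforms(frame) for frame in cam_val]
        stacked = torch.stack(frames, dim=0)
        n, c, h, w = stacked.shape
        return stacked.view(n * c, h, w)
    return image_transforms(cam_val)


if __name__ == "__main__":
    identity = lambda x: x                   # stand-in for dataset.image_transforms
    single = torch.rand(224, 224, 3)         # [H, W, C]
    stacked = torch.rand(2, 224, 224, 3)     # [N, H, W, C], e.g. from delta_timestamps
    print(normalize_and_transform(single, identity).shape)   # torch.Size([3, 224, 224])
    print(normalize_and_transform(stacked, identity).shape)  # torch.Size([6, 224, 224])
```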