Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions src/lerobot/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
import shutil
import tempfile
import torch
from collections.abc import Callable
from pathlib import Path

Expand Down Expand Up @@ -994,8 +995,46 @@ def __getitem__(self, idx) -> dict:
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])

cam_val = item.get(cam, None)
if cam_val is None:
continue

if not isinstance(cam_val, torch.Tensor):
try:
cam_val = torch.as_tensor(cam_val)
except Exception:
item[cam] = self.image_transforms(cam_val)
continue

# Convert [N, H, W, C] → [N, C, H, W]
if cam_val.dim() == 4 and cam_val.shape[-1] == 3:
cam_val = cam_val.permute(0, 3, 1, 2)

# Convert [H, W, C] → [C, H, W]
if cam_val.dim() == 3 and cam_val.shape[-1] == 3:
cam_val = cam_val.permute(2, 0, 1)

# Multi-frame case
if cam_val.dim() == 4:
frames = []
for f_idx in range(cam_val.shape[0]):
frame = cam_val[f_idx]
frame_t = self.image_transforms(frame)
if not isinstance(frame_t, torch.Tensor):
frame_t = torch.as_tensor(frame_t)
if frame_t.dim() == 3 and frame_t.shape[-1] == 3:
frame_t = frame_t.permute(2, 0, 1)
frames.append(frame_t)
try:
item[cam] = torch.cat(frames, dim=0)
except Exception:
stacked = torch.stack(frames, dim=0)
N, C, H, W = stacked.shape
item[cam] = stacked.view(N * C, H, W)
elif cam_val.dim() == 3:
item[cam] = self.image_transforms(cam_val)
else:
item[cam] = self.image_transforms(cam_val)
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
Expand Down