Skip to content

Commit 05077b7

Browse files
author
Your Name
committed
lerobot first attempt
1 parent 57a14f6 commit 05077b7

File tree

5 files changed

+1648
-306
lines changed

5 files changed

+1648
-306
lines changed
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
#!/usr/bin/env python3
"""
LeRobot to RoboDM Dataset Ingestion Pipeline

This module handles the conversion of LeRobot datasets to RoboDM format for parallel processing.
It provides a clean ingestion interface that can be used standalone or as part of a larger pipeline.

Usage:
    python lerobot_to_robodm_ingestion.py --dataset lerobot/pusht --num_episodes 50 --output_dir ./robodm_data
"""

import argparse
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np

# RoboDM imports
from robodm.trajectory import Trajectory

# LeRobot is an optional dependency: record availability in a module flag so
# the rest of the module can raise a clear error instead of failing on import.
try:
    from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

    # Prefer the pyav backend for video decoding when the hook is exposed.
    import lerobot.datasets.video_utils as video_utils

    if hasattr(video_utils, 'set_video_backend'):
        video_utils.set_video_backend('pyav')
    LEROBOT_AVAILABLE = True
except ImportError:
    print("LeRobot not available. Please install lerobot package.")
    LEROBOT_AVAILABLE = False
class LeRobotToRoboDMIngestion:
    """Handles conversion of LeRobot datasets to RoboDM format.

    One ``.vla`` trajectory file is written per episode under ``output_dir``.
    """

    def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
        """
        Initialize the ingestion pipeline.

        Args:
            dataset_name: Name of the LeRobot dataset (e.g., 'lerobot/pusht').
            output_dir: Directory to save RoboDM trajectories. If None, uses a
                fresh temporary directory.

        Raises:
            ImportError: If the lerobot package is not installed.
        """
        if not LEROBOT_AVAILABLE:
            raise ImportError("LeRobot is not available. Please install lerobot package.")

        self.dataset_name = dataset_name
        self.output_dir = output_dir or tempfile.mkdtemp(prefix="robodm_lerobot_")
        os.makedirs(self.output_dir, exist_ok=True)

        # Recording rate used to derive per-frame timestamps. Defaults to the
        # historical hard-coded 10 FPS; overridden from metadata when available.
        self.fps: float = 10.0

        # Load dataset metadata (best effort -- conversion proceeds without it).
        try:
            self.metadata = LeRobotDatasetMetadata(dataset_name)
            print(f"Dataset info: {self.metadata.total_episodes} episodes, {self.metadata.total_frames} frames")
            # NOTE(review): LeRobot metadata exposes the dataset frame rate as
            # `fps`; guarded with getattr in case the attribute name differs
            # across lerobot versions -- falls back to 10 FPS.
            fps = getattr(self.metadata, "fps", None)
            if fps:
                self.fps = float(fps)
        except Exception as e:
            print(f"Could not load metadata: {e}. Proceeding without metadata.")
            self.metadata = None

    def ingest(self, num_episodes: Optional[int] = None, video_backend: str = 'pyav') -> str:
        """
        Convert the LeRobot dataset to RoboDM format.

        Args:
            num_episodes: Number of episodes to convert. If None, converts all episodes.
            video_backend: Video backend to use for processing ('pyav' or 'opencv').

        Returns:
            Path to the directory containing converted RoboDM trajectories.
        """
        print(f"Starting ingestion of {self.dataset_name}")
        print(f"Output directory: {self.output_dir}")

        # Restrict to the first N episodes only when both a limit and metadata
        # (for the episode count) are available.
        episodes_to_load = None
        if num_episodes is not None and self.metadata is not None:
            episodes_to_load = list(range(min(num_episodes, self.metadata.total_episodes)))

        print(f"Loading dataset with episodes: {episodes_to_load if episodes_to_load else 'all'}")
        lerobot_dataset = self._load_lerobot_dataset(episodes_to_load, video_backend)

        self._convert_to_robodm(lerobot_dataset)

        print(f"✅ Ingestion completed successfully!")
        print(f"RoboDM trajectories saved to: {self.output_dir}")
        return self.output_dir

    def _load_lerobot_dataset(self, episodes_to_load: Optional[list], video_backend: str) -> "LeRobotDataset":
        """Load the LeRobot dataset, preferring an explicit video backend.

        Falls back to the default backend when the installed lerobot version
        does not accept a ``video_backend`` keyword.
        """
        try:
            dataset = LeRobotDataset(
                self.dataset_name,
                episodes=episodes_to_load,
                video_backend=video_backend
            )
        except TypeError:
            # Older lerobot versions do not support the video_backend parameter.
            dataset = LeRobotDataset(self.dataset_name, episodes=episodes_to_load)

        print(f"Dataset loaded with {len(dataset)} samples")
        return dataset

    def _convert_to_robodm(self, lerobot_dataset: "LeRobotDataset"):
        """Group flat samples by episode and write one trajectory per episode."""
        # Group samples by episode index; each entry is (frame_index, sample).
        episodes_data: Dict[int, list] = {}
        for sample in lerobot_dataset:
            episode_idx = sample['episode_index'].item()
            frame_idx = sample['frame_index'].item()
            episodes_data.setdefault(episode_idx, []).append((frame_idx, sample))

        # Samples may arrive out of order; sort each episode by frame index.
        for frames in episodes_data.values():
            frames.sort(key=lambda x: x[0])

        print(f"Converting {len(episodes_data)} episodes to RoboDM format...")

        for episode_idx, frames in episodes_data.items():
            self._convert_episode_to_trajectory(episode_idx, frames)

    def _convert_episode_to_trajectory(self, episode_idx: int, frames: list):
        """Write a single episode's frames into one RoboDM trajectory file."""
        trajectory_path = os.path.join(self.output_dir, f"episode_{episode_idx:03d}.vla")
        traj = Trajectory(path=trajectory_path, mode="w")

        # Milliseconds between frames, derived from the dataset frame rate
        # (10 FPS -> 100 ms when the rate is unknown).
        frame_period_ms = 1000.0 / getattr(self, "fps", 10.0)

        try:
            for frame_idx, sample in frames:
                timestamp = int(round(frame_idx * frame_period_ms))

                # Image observations (primary camera plus any extra cameras).
                self._add_image_observations(traj, sample, timestamp)

                # Proprioceptive state.
                if 'observation.state' in sample:
                    state = sample['observation.state'].numpy().astype(np.float32)
                    traj.add("observation/state", state, timestamp=timestamp, time_unit="ms")

                # Actions.
                if 'action' in sample:
                    action = sample['action'].numpy().astype(np.float32)
                    traj.add("action", action, timestamp=timestamp, time_unit="ms")

                # Reward and done signals, when the dataset provides them.
                if 'next.reward' in sample:
                    reward = sample['next.reward'].numpy().astype(np.float32)
                    traj.add("reward", reward, timestamp=timestamp, time_unit="ms")

                if 'next.done' in sample:
                    done = sample['next.done'].numpy().astype(np.bool_)
                    traj.add("done", done, timestamp=timestamp, time_unit="ms")
        finally:
            # Always close so a partially written trajectory is flushed.
            traj.close()

    @staticmethod
    def _to_hwc_uint8(image_tensor) -> np.ndarray:
        """Convert a CHW image tensor to an HWC numpy array.

        Images whose values lie in [0, 1] are rescaled to uint8; anything else
        is returned unscaled with its original dtype (matching the previous
        inline behavior).
        """
        image = image_tensor.permute(1, 2, 0).numpy()
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        return image

    def _add_image_observations(self, traj: "Trajectory", sample: Dict[str, Any], timestamp: int):
        """Add every image observation in *sample* to *traj* at *timestamp* (ms)."""
        # Primary image observation.
        if 'observation.image' in sample:
            traj.add("observation/image", self._to_hwc_uint8(sample['observation.image']),
                     timestamp=timestamp, time_unit="ms")

        # Additional camera observations, e.g. 'observation.images.wrist' or
        # 'observation.image_front'.
        for key in sample.keys():
            if key.startswith('observation.images.'):
                camera_name = key.split('.')[-1]
            elif key.startswith('observation.image') and key != 'observation.image':
                camera_name = key.split('.')[-1] if '.' in key else key.replace('observation.', '')
            else:
                continue
            traj.add(f"observation/images/{camera_name}", self._to_hwc_uint8(sample[key]),
                     timestamp=timestamp, time_unit="ms")

    def get_conversion_stats(self) -> Dict[str, Any]:
        """Return statistics about the converted dataset on disk.

        Returns:
            Dict with the output directory, number of ``.vla`` trajectory
            files, their paths, and their combined size in megabytes.
        """
        trajectory_files = list(Path(self.output_dir).glob("*.vla"))
        return {
            "output_directory": self.output_dir,
            "num_trajectories": len(trajectory_files),
            "trajectory_files": [str(f) for f in trajectory_files],
            "total_size_mb": sum(f.stat().st_size for f in trajectory_files) / (1024 * 1024)
        }
def main():
    """Command-line entry point: parse arguments, run ingestion, print stats."""
    cli = argparse.ArgumentParser(description="Convert LeRobot dataset to RoboDM format")
    cli.add_argument("--dataset", type=str, required=True,
                    help="LeRobot dataset name (e.g., lerobot/pusht)")
    cli.add_argument("--num_episodes", type=int, default=None,
                    help="Number of episodes to convert (None for all)")
    cli.add_argument("--output_dir", type=str, default=None,
                    help="Output directory for RoboDM trajectories")
    cli.add_argument("--video_backend", type=str, default='pyav',
                    choices=['pyav', 'opencv'], help="Video backend to use")
    opts = cli.parse_args()

    # Build the pipeline, then run the conversion.
    pipeline = LeRobotToRoboDMIngestion(
        dataset_name=opts.dataset,
        output_dir=opts.output_dir
    )
    pipeline.ingest(
        num_episodes=opts.num_episodes,
        video_backend=opts.video_backend
    )

    # Summarize what was written to disk.
    summary = pipeline.get_conversion_stats()
    print("\n📊 Conversion Statistics:")
    print(f"   Output directory: {summary['output_directory']}")
    print(f"   Trajectories converted: {summary['num_trajectories']}")
    print(f"   Total size: {summary['total_size_mb']:.2f} MB")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)