Title: Performance regression in _query_hf_dataset when using delta_timestamps (~40x slower)
Summary
LeRobot v3.0's _query_hf_dataset implementation has a severe performance regression when querying multiple indices (e.g., for action horizons with delta_timestamps). The current implementation is ~40x slower than using select().
Root Cause
In commit a4aa3164 ("fix data access bottleneck"), the implementation changed from using select() to direct indexing:
```python
# Old approach (fast) - used select()
torch.stack(self.hf_dataset.select(q_idx)[key])

# New approach (slow) - direct indexing.
# _query_hf_dataset first tries column-first access, which loads the ENTIRE column:
torch.stack(self.hf_dataset[key][indices])  # loads the whole column first (~3s)

# It then falls back to row-first access:
torch.stack(self.hf_dataset[indices][key])  # still slow compared to select()
```

The column-first approach `hf_dataset[key][indices]` loads the entire column into memory before indexing, which is extremely slow. Even after falling back to row-first access, direct indexing via `hf_dataset[indices][key]` is still ~40x slower than using `select()`.
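To make the cost difference concrete, here is a dependency-free sketch using a hypothetical stand-in class (not the real `datasets` API) that counts how many values each access pattern materializes for a 50-frame window over 200,000 frames:

```python
# Hypothetical stand-in for an Arrow-backed table; NOT the real `datasets` API.
class FakeHFDataset:
    def __init__(self, rows):
        self.rows = rows          # list of dicts, one per frame
        self.materialized = 0     # counts values touched, to expose the cost

    def __getitem__(self, key_or_indices):
        if isinstance(key_or_indices, str):
            # Column-first access: materializes the ENTIRE column.
            self.materialized += len(self.rows)
            return [row[key_or_indices] for row in self.rows]
        # Row-first access: materializes only the requested rows.
        self.materialized += len(key_or_indices)
        return {k: [self.rows[i][k] for i in key_or_indices] for k in self.rows[0]}

    def select(self, indices):
        # Like datasets.Dataset.select(): keep only the requested rows.
        self.materialized += len(indices)
        return FakeHFDataset([self.rows[i] for i in indices])


ds = FakeHFDataset([{"actions": i} for i in range(200_000)])
window = list(range(1_000, 1_050))  # a 50-frame action horizon

ds.materialized = 0
column = ds["actions"]                 # loads all 200,000 values...
slow = [column[i] for i in window]     # ...just to read 50 of them
cost_column_first = ds.materialized

ds.materialized = 0
rowfirst = ds[window]["actions"]       # the row-first fallback touches 50 rows here,
cost_row_first = ds.materialized       # but real `datasets` still copies each row into
                                       # Python objects, so it stays well behind select()

ds.materialized = 0
fast = ds.select(window)["actions"]    # select() narrows to 50 rows first
cost_select = ds.materialized

assert slow == rowfirst == fast
print(cost_column_first, cost_row_first, cost_select)  # 200000 50 50
```

In the real library the row-first gap comes from per-row conversion cost rather than raw value counts, but the benchmark below shows it is still ~40x slower than `select()`.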
Benchmark Results
Running on a parquet-based dataset with delta_timestamps for a 50-frame action horizon:
```text
Benchmarking dataset: JianZhou0420/libero_openvla_LeRobotv3_0
  Num samples: 20
  Action horizon: 50
  FPS: 10
  Total frames: 273465

============================================================
BENCHMARK RESULTS
============================================================
1. hf_dataset[i] (single frame):
   0.152s total, 7.6ms per sample
2. hf_dataset.select(indices)[key] (OpenPi's fast approach):
   0.128s total, 6.4ms per sample
3. hf_dataset[indices][key] (LeRobot v3.0's slow approach):
   5.253s total, 262.6ms per sample
4. LeRobotDataset[i] with delta_timestamps (what training uses):
   6.980s total, 349.0ms per sample
============================================================
ANALYSIS
============================================================
select() is 40.6x faster than direct indexing

Estimated time per epoch with current LeRobotDataset:
  26.5 hours (1591 minutes)
Estimated time per epoch with select() fix:
  0.56 hours (34 minutes)
```
| Method | Time per sample | Relative |
|---|---|---|
| `hf_dataset.select(indices)[key]` | 6.4ms | 1x (baseline) |
| `hf_dataset[indices][key]` | 262.6ms | 41x slower |
| `LeRobotDataset[i]` with `delta_timestamps` | 349.0ms | 55x slower |
Impact: Training time increases from ~34 minutes to ~26.5 hours per epoch.
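As a sanity check, those epoch estimates can be re-derived from the per-sample numbers above (the ~7.4 ms fixed-path figure is the one implied by the 0.56 h estimate, slightly above the 6.4 ms microbenchmark):

```python
# Re-deriving the per-epoch estimates from the benchmark numbers above.
total_frames = 273_465
per_sample_current = 0.349   # seconds: LeRobotDataset[i] with delta_timestamps
per_sample_fixed = 0.0074    # seconds: select() path, as implied by the 0.56 h estimate

hours_current = total_frames * per_sample_current / 3600
hours_fixed = total_frames * per_sample_fixed / 3600
print(f"{hours_current:.1f} h per epoch now, {hours_fixed:.2f} h with the fix")
# 26.5 h per epoch now, 0.56 h with the fix
```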
Use Case
This affects Pi0/OpenPi style action chunking where we need to query 50-frame action horizons using delta_timestamps. This is a common pattern for diffusion-based action prediction models.
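For context, the mapping from delta_timestamps to frame indices can be sketched roughly as follows (assuming contiguous frames at a fixed fps; this is an illustration, not LeRobot's actual implementation, which also handles episode boundaries and timestamp tolerance):

```python
fps = 10
horizon = 50
delta_timestamps = {"actions": [t / fps for t in range(horizon)]}  # 0.0s .. 4.9s ahead

def query_indices(current_idx: int, deltas: list[float], fps: int, num_frames: int) -> list[int]:
    # Each delta (in seconds) shifts the current frame by round(delta * fps);
    # clamping to the last frame stands in for episode-boundary handling.
    return [min(current_idx + round(d * fps), num_frames - 1) for d in deltas]

idx = query_indices(100, delta_timestamps["actions"], fps, num_frames=273_465)
print(idx[0], idx[-1], len(idx))  # 100 149 50
```

Every `__getitem__` during training therefore has to fetch 50 rows, which is exactly the multi-index query path that regressed.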
```python
# Pi0/OpenPi action chunking pattern
delta_timestamps = {"actions": [t / fps for t in range(50)]}  # 50-frame horizon
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)

# Each __getitem__ call queries 50 indices - very slow with the current implementation
sample = dataset[i]
```

Suggested Fix
Use `select()` for multi-index queries in `_query_hf_dataset`:

```python
# Fast approach
torch.stack(self.hf_dataset.select(indices)[key])
```

Benchmark Script
Click to expand benchmark_dataset_loading.py
```python
"""Benchmark LeRobotDataset loading speed.

Compares different data access methods to measure the performance impact
of LeRobot v3.0's slow _query_hf_dataset implementation.

Usage:
    python scripts/benchmark_dataset_loading.py
    python scripts/benchmark_dataset_loading.py --num-samples 100
    python scripts/benchmark_dataset_loading.py --repo-id your/dataset
"""
import argparse
import time

import torch


def benchmark_hf_dataset_direct(dataset, num_samples: int) -> float:
    """Benchmark direct hf_dataset access (no delta_timestamps processing)."""
    start = time.time()
    for i in range(num_samples):
        _ = dataset.hf_dataset[i]
    return time.time() - start


def benchmark_hf_dataset_select(dataset, num_samples: int, horizon: int = 50) -> float:
    """Benchmark hf_dataset.select() for action horizon queries."""
    start = time.time()
    for i in range(num_samples):
        indices = list(range(i, min(i + horizon, len(dataset))))
        _ = torch.stack(dataset.hf_dataset.select(indices)["actions"])
    return time.time() - start


def benchmark_hf_dataset_direct_indexing(dataset, num_samples: int, horizon: int = 50) -> float:
    """Benchmark hf_dataset[indices][key] (LeRobot v3.0's slow approach)."""
    start = time.time()
    for i in range(num_samples):
        indices = list(range(i, min(i + horizon, len(dataset))))
        _ = torch.stack(dataset.hf_dataset[indices]["actions"])
    return time.time() - start


def benchmark_lerobot_dataset(dataset_with_delta, num_samples: int) -> float:
    """Benchmark LeRobotDataset with delta_timestamps (full __getitem__)."""
    start = time.time()
    for i in range(num_samples):
        _ = dataset_with_delta[i]
    return time.time() - start


def main():
    parser = argparse.ArgumentParser(description="Benchmark LeRobotDataset loading speed")
    parser.add_argument("--repo-id", default="JianZhou0420/libero_openvla_LeRobotv3_0",
                        help="HuggingFace dataset repo ID")
    parser.add_argument("--num-samples", type=int, default=50,
                        help="Number of samples to benchmark")
    parser.add_argument("--horizon", type=int, default=50,
                        help="Action horizon for delta_timestamps")
    args = parser.parse_args()

    print(f"Benchmarking dataset: {args.repo_id}")
    print(f"  Num samples: {args.num_samples}")
    print(f"  Action horizon: {args.horizon}")
    print()

    # Import here to see any warnings
    from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

    # Load metadata
    meta = LeRobotDatasetMetadata(args.repo_id)
    fps = meta.fps
    print(f"  FPS: {fps}")
    print(f"  Total frames: {meta.total_frames}")
    print()

    # Load dataset WITHOUT delta_timestamps
    print("Loading dataset without delta_timestamps...")
    dataset_simple = LeRobotDataset(args.repo_id)

    # Load dataset WITH delta_timestamps
    print("Loading dataset with delta_timestamps...")
    delta_timestamps = {"actions": [t / fps for t in range(args.horizon)]}
    dataset_with_delta = LeRobotDataset(args.repo_id, delta_timestamps=delta_timestamps)
    print()

    print("=" * 60)
    print("BENCHMARK RESULTS")
    print("=" * 60)

    # Benchmark 1: direct hf_dataset access
    elapsed = benchmark_hf_dataset_direct(dataset_simple, args.num_samples)
    print("1. hf_dataset[i] (single frame):")
    print(f"   {elapsed:.3f}s total, {elapsed / args.num_samples * 1000:.1f}ms per sample")
    print()

    # Benchmark 2: select() method (fast)
    elapsed = benchmark_hf_dataset_select(dataset_simple, args.num_samples, args.horizon)
    print("2. hf_dataset.select(indices)[key] (OpenPi's fast approach):")
    print(f"   {elapsed:.3f}s total, {elapsed / args.num_samples * 1000:.1f}ms per sample")
    print()

    # Benchmark 3: direct indexing (slow - LeRobot v3.0's approach)
    elapsed = benchmark_hf_dataset_direct_indexing(dataset_simple, args.num_samples, args.horizon)
    print("3. hf_dataset[indices][key] (LeRobot v3.0's slow approach):")
    print(f"   {elapsed:.3f}s total, {elapsed / args.num_samples * 1000:.1f}ms per sample")
    print()

    # Benchmark 4: full LeRobotDataset with delta_timestamps
    elapsed = benchmark_lerobot_dataset(dataset_with_delta, args.num_samples)
    print("4. LeRobotDataset[i] with delta_timestamps (what training uses):")
    print(f"   {elapsed:.3f}s total, {elapsed / args.num_samples * 1000:.1f}ms per sample")
    print()

    print("=" * 60)
    print("ANALYSIS")
    print("=" * 60)

    # Calculate speedup
    t_select = benchmark_hf_dataset_select(dataset_simple, 10, args.horizon) / 10
    t_direct = benchmark_hf_dataset_direct_indexing(dataset_simple, 10, args.horizon) / 10
    speedup = t_direct / t_select if t_select > 0 else 0
    print(f"select() is {speedup:.1f}x faster than direct indexing")
    print()

    # Estimate full-epoch impact (uses benchmark 4's per-sample time)
    total_frames = meta.total_frames
    t_per_sample = elapsed / args.num_samples
    estimated_epoch_time = total_frames * t_per_sample
    print("Estimated time per epoch with current LeRobotDataset:")
    print(f"  {estimated_epoch_time / 3600:.1f} hours ({estimated_epoch_time / 60:.0f} minutes)")
    print()

    # Estimate with the fix
    t_fast = benchmark_hf_dataset_select(dataset_simple, args.num_samples, args.horizon) / args.num_samples
    estimated_fast_epoch = total_frames * t_fast
    print("Estimated time per epoch with select() fix:")
    print(f"  {estimated_fast_epoch / 3600:.2f} hours ({estimated_fast_epoch / 60:.0f} minutes)")


if __name__ == "__main__":
    main()
```

Environment
- LeRobot version: v3.0 (latest main)
- Dataset format: parquet (images stored in parquet, not video)
- Python: 3.10
- Platform: Linux