Performance regression in _query_hf_dataset when using delta_timestamps (~40x slower) #2895

@jianzhou0420

Description

Summary

LeRobot v3.0's _query_hf_dataset implementation has a severe performance regression when querying multiple indices (e.g., for action horizons with delta_timestamps). The current implementation is ~40x slower than using select().

Root Cause

In commit a4aa3164 ("fix data access bottleneck"), the implementation changed from using select() to direct indexing:

# Old approach (fast) - used select()
torch.stack(self.hf_dataset.select(q_idx)[key])

# New approach (slow) - direct indexing
# In _query_hf_dataset, it tries column-first access which loads ENTIRE column
torch.stack(self.hf_dataset[key][indices])  # This loads entire column first (~3s)
# Then falls back to row-first
torch.stack(self.hf_dataset[indices][key])  # Still slow compared to select()

The column-first approach hf_dataset[key][indices] loads the entire column into memory before indexing, which is extremely slow. Even after falling back to row-first access, direct indexing hf_dataset[indices][key] is still ~40x slower than using select().
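
To see the effect outside of LeRobot, the snippet below (an illustration, not LeRobot code) times the three access patterns on a toy datasets.Dataset with torch formatting. The exact ratios depend on dataset size and on whether the table is memory-mapped from parquet, but select() is the only variant that never materializes more than the queried rows.

import time

import torch
from datasets import Dataset

# Toy stand-in for a parquet-backed LeRobot table: 100k rows of 7-dim actions.
ds = Dataset.from_dict({"actions": [[float(i)] * 7 for i in range(100_000)]}).with_format("torch")
indices = list(range(500, 550))  # a 50-frame action horizon

def timed(label, fn):
    t0 = time.perf_counter()
    fn()
    print(f"{label}: {time.perf_counter() - t0:.4f}s")

timed("column-first ds['actions'][indices]", lambda: ds["actions"][indices])  # materializes the whole column first
timed("row-first    ds[indices]['actions']", lambda: ds[indices]["actions"])  # the fallback path
timed("select()     ds.select(...)['actions']", lambda: torch.stack(list(ds.select(indices)["actions"])))  # queried rows only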

Benchmark Results

Running on a parquet-based dataset with delta_timestamps for 50-frame action horizon:

Benchmarking dataset: JianZhou0420/libero_openvla_LeRobotv3_0
  Num samples: 20
  Action horizon: 50

  FPS: 10
  Total frames: 273465

============================================================
BENCHMARK RESULTS
============================================================
1. hf_dataset[i] (single frame):
   0.152s total, 7.6ms per sample

2. hf_dataset.select(indices)[key] (OpenPi's fast approach):
   0.128s total, 6.4ms per sample

3. hf_dataset[indices][key] (LeRobot v3.0's slow approach):
   5.253s total, 262.6ms per sample

4. LeRobotDataset[i] with delta_timestamps (what training uses):
   6.980s total, 349.0ms per sample

============================================================
ANALYSIS
============================================================
select() is 40.6x faster than direct indexing

Estimated time per epoch with current LeRobotDataset:
  26.5 hours (1591 minutes)

Estimated time per epoch with select() fix:
  0.56 hours (34 minutes)

Method                                    Time per sample   Relative
hf_dataset.select(indices)[key]           6.4ms             1x (baseline)
hf_dataset[indices][key]                  262.6ms           41x slower
LeRobotDataset[i] with delta_timestamps   349.0ms           55x slower

Impact: Training time increases from ~34 minutes to ~26.5 hours per epoch.
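
The epoch estimates are simply the per-sample times scaled by the total frame count. A quick sanity check of the arithmetic (the ~7.4 ms/sample for the select() path is implied by the 0.56 h figure, since the script re-times select() when producing that estimate):

total_frames = 273_465
print(total_frames * 0.3490 / 3600)  # current LeRobotDataset path: ~26.5 hours per epoch
print(total_frames * 0.0074 / 3600)  # select()-based path:        ~0.56 hours (~34 minutes)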

Use Case

This affects Pi0/OpenPi style action chunking where we need to query 50-frame action horizons using delta_timestamps. This is a common pattern for diffusion-based action prediction models.

# Pi0/OpenPi action chunking pattern
delta_timestamps = {"actions": [t / fps for t in range(50)]}  # 50-frame horizon
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)

# Each __getitem__ call queries 50 indices - very slow with current implementation
sample = dataset[i]
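
Roughly, each delta timestamp is converted to a frame-index offset (delta * fps) relative to the current index, with out-of-range frames handled by padding at episode boundaries. The helper below is a hypothetical illustration of that mapping (not LeRobot's actual code, which also returns padding masks); it shows why a single __getitem__ call turns into a 50-index query against hf_dataset.

def horizon_indices(current_idx: int, deltas: list[float], fps: float,
                    ep_start: int, ep_end: int) -> list[int]:
    """Hypothetical helper: map delta_timestamps (seconds) to absolute frame indices
    around current_idx, clamped to the episode's [ep_start, ep_end) range."""
    offsets = [round(d * fps) for d in deltas]
    return [min(max(current_idx + off, ep_start), ep_end - 1) for off in offsets]

# With fps=10, a 50-step horizon, and an episode spanning frames [100, 160):
indices = horizon_indices(120, [t / 10 for t in range(50)], fps=10, ep_start=100, ep_end=160)
# -> [120, 121, ..., 159, 159, ...]  (clamped at the episode end)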

Suggested Fix

Use select() for multi-index queries in _query_hf_dataset:

# Fast approach
torch.stack(self.hf_dataset.select(indices)[key])
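
For concreteness, a patched _query_hf_dataset might look like the sketch below, which mirrors the pre-a4aa3164 select()-based behavior. The signature and the shape of query_indices (one list of row indices per feature key) are assumptions based on how the method is used; video keys are assumed to be decoded elsewhere.

def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
    # Restrict the Arrow table to the queried rows first, then read the column.
    return {
        key: torch.stack(self.hf_dataset.select(q_idx)[key])
        for key, q_idx in query_indices.items()
        if key not in self.meta.video_keys  # assumed: video frames are decoded by the video path
    }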

Benchmark Script

benchmark_dataset_loading.py:
"""Benchmark LeRobotDataset loading speed.

Compares different data access methods to measure performance impact
of LeRobot v3.0's slow _query_hf_dataset implementation.

Usage:
    python scripts/benchmark_dataset_loading.py
    python scripts/benchmark_dataset_loading.py --num-samples 100
    python scripts/benchmark_dataset_loading.py --repo-id your/dataset
"""

import argparse
import time

import torch


def benchmark_hf_dataset_direct(dataset, num_samples: int) -> float:
    """Benchmark direct hf_dataset access (no delta_timestamps processing)."""
    start = time.time()
    for i in range(num_samples):
        _ = dataset.hf_dataset[i]
    return time.time() - start


def benchmark_hf_dataset_select(dataset, num_samples: int, horizon: int = 50) -> float:
    """Benchmark hf_dataset.select() for action horizon queries."""
    start = time.time()
    for i in range(num_samples):
        indices = list(range(i, min(i + horizon, len(dataset))))
        _ = torch.stack(dataset.hf_dataset.select(indices)['actions'])
    return time.time() - start


def benchmark_hf_dataset_direct_indexing(dataset, num_samples: int, horizon: int = 50) -> float:
    """Benchmark hf_dataset[indices][key] (LeRobot v3.0's slow approach)."""
    start = time.time()
    for i in range(num_samples):
        indices = list(range(i, min(i + horizon, len(dataset))))
        _ = torch.stack(dataset.hf_dataset[indices]['actions'])
    return time.time() - start


def benchmark_lerobot_dataset(dataset_with_delta, num_samples: int) -> float:
    """Benchmark LeRobotDataset with delta_timestamps (full __getitem__)."""
    start = time.time()
    for i in range(num_samples):
        _ = dataset_with_delta[i]
    return time.time() - start


def main():
    parser = argparse.ArgumentParser(description="Benchmark LeRobotDataset loading speed")
    parser.add_argument("--repo-id", default="JianZhou0420/libero_openvla_LeRobotv3_0",
                        help="HuggingFace dataset repo ID")
    parser.add_argument("--num-samples", type=int, default=50,
                        help="Number of samples to benchmark")
    parser.add_argument("--horizon", type=int, default=50,
                        help="Action horizon for delta_timestamps")
    args = parser.parse_args()

    print(f"Benchmarking dataset: {args.repo_id}")
    print(f"  Num samples: {args.num_samples}")
    print(f"  Action horizon: {args.horizon}")
    print()

    # Import here to see any warnings
    from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

    # Load metadata
    meta = LeRobotDatasetMetadata(args.repo_id)
    fps = meta.fps
    print(f"  FPS: {fps}")
    print(f"  Total frames: {meta.total_frames}")
    print()

    # Load dataset WITHOUT delta_timestamps
    print("Loading dataset without delta_timestamps...")
    dataset_simple = LeRobotDataset(args.repo_id)

    # Load dataset WITH delta_timestamps
    print("Loading dataset with delta_timestamps...")
    delta_timestamps = {"actions": [t / fps for t in range(args.horizon)]}
    dataset_with_delta = LeRobotDataset(args.repo_id, delta_timestamps=delta_timestamps)

    print()
    print("=" * 60)
    print("BENCHMARK RESULTS")
    print("=" * 60)

    # Benchmark 1: Direct hf_dataset access
    elapsed = benchmark_hf_dataset_direct(dataset_simple, args.num_samples)
    print(f"1. hf_dataset[i] (single frame):")
    print(f"   {elapsed:.3f}s total, {elapsed/args.num_samples*1000:.1f}ms per sample")
    print()

    # Benchmark 2: select() method (fast)
    elapsed = benchmark_hf_dataset_select(dataset_simple, args.num_samples, args.horizon)
    print(f"2. hf_dataset.select(indices)[key] (OpenPi's fast approach):")
    print(f"   {elapsed:.3f}s total, {elapsed/args.num_samples*1000:.1f}ms per sample")
    print()

    # Benchmark 3: Direct indexing (slow - LeRobot v3.0's approach)
    elapsed = benchmark_hf_dataset_direct_indexing(dataset_simple, args.num_samples, args.horizon)
    print(f"3. hf_dataset[indices][key] (LeRobot v3.0's slow approach):")
    print(f"   {elapsed:.3f}s total, {elapsed/args.num_samples*1000:.1f}ms per sample")
    print()

    # Benchmark 4: Full LeRobotDataset with delta_timestamps
    elapsed = benchmark_lerobot_dataset(dataset_with_delta, args.num_samples)
    print(f"4. LeRobotDataset[i] with delta_timestamps (what training uses):")
    print(f"   {elapsed:.3f}s total, {elapsed/args.num_samples*1000:.1f}ms per sample")
    print()

    print("=" * 60)
    print("ANALYSIS")
    print("=" * 60)

    # Calculate speedup
    t_select = benchmark_hf_dataset_select(dataset_simple, 10, args.horizon) / 10
    t_direct = benchmark_hf_dataset_direct_indexing(dataset_simple, 10, args.horizon) / 10
    speedup = t_direct / t_select if t_select > 0 else 0

    print(f"select() is {speedup:.1f}x faster than direct indexing")
    print()

    # Estimate full training impact
    total_frames = meta.total_frames
    t_per_sample = elapsed / args.num_samples
    estimated_epoch_time = total_frames * t_per_sample
    print(f"Estimated time per epoch with current LeRobotDataset:")
    print(f"  {estimated_epoch_time/3600:.1f} hours ({estimated_epoch_time/60:.0f} minutes)")
    print()

    # Estimate with fix
    t_fast = benchmark_hf_dataset_select(dataset_simple, args.num_samples, args.horizon) / args.num_samples
    estimated_fast_epoch = total_frames * t_fast
    print(f"Estimated time per epoch with select() fix:")
    print(f"  {estimated_fast_epoch/3600:.2f} hours ({estimated_fast_epoch/60:.0f} minutes)")


if __name__ == "__main__":
    main()

Environment

  • LeRobot version: v3.0 (latest main)
  • Dataset format: parquet (images stored in parquet, not video)
  • Python: 3.10
  • Platform: Linux

Labels

  • dataset: Issues regarding data inputs, processing, or datasets
  • evaluation: For issues or PRs related to environment evaluation, and benchmarks
  • examples: Issues related to the examples
  • performance: Issues aimed at improving speed or resource usage
  • training: Issues related to training time
