6 changes: 3 additions & 3 deletions pyproject.toml
@@ -323,9 +323,9 @@ ignore_errors = false
 # module = "lerobot.processor.*"
 # ignore_errors = false

-# [[tool.mypy.overrides]]
-# module = "lerobot.datasets.*"
-# ignore_errors = false
+[[tool.mypy.overrides]]
+module = "lerobot.datasets.*"
+ignore_errors = false

 [[tool.mypy.overrides]]
 module = "lerobot.cameras.*"

8 changes: 4 additions & 4 deletions src/lerobot/datasets/aggregate.py
@@ -171,8 +171,8 @@ def aggregate_datasets(
     aggr_repo_id: str,
     roots: list[Path] | None = None,
     aggr_root: Path | None = None,
-    data_files_size_in_mb: float | None = None,
-    video_files_size_in_mb: float | None = None,
+    data_files_size_in_mb: int | None = None,
+    video_files_size_in_mb: int | None = None,
     chunk_size: int | None = None,
 ):
     """Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -450,7 +450,7 @@ def append_or_create_parquet_file(
     chunk_size: int,
     default_path: str,
     contains_images: bool = False,
-    aggr_root: Path = None,
+    aggr_root: Path = Path.cwd(),
 ):
     """Appends data to an existing parquet file or creates a new one based on size constraints.

@@ -465,7 +465,7 @@ def append_or_create_parquet_file(
         chunk_size: Maximum number of files per chunk before incrementing chunk index.
         default_path: Format string for generating file paths.
         contains_images: Whether the data contains images requiring special handling.
-        aggr_root: Root path for the aggregated dataset.
+        aggr_root: Root path for the aggregated dataset. Defaults to Path.cwd()

     Returns:
         dict: Updated index dictionary with current chunk and file indices.

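For orientation, the size-based rollover behavior described in the docstring of append_or_create_parquet_file can be sketched as a standalone helper (hypothetical names and logic, not the function's actual implementation): data is appended to the current parquet file until a size budget in MB would be exceeded, and the file index wraps into a new chunk once chunk_size files have been written.

def next_write_location(
    current_file_size_mb: float,
    incoming_size_mb: float,
    max_file_size_mb: int,  # the size budgets are plain ints (MB) after this change
    chunk_idx: int,
    file_idx: int,
    chunk_size: int,
) -> tuple[int, int, bool]:
    """Return (chunk_idx, file_idx, start_new_file) for the next append."""
    if current_file_size_mb + incoming_size_mb <= max_file_size_mb:
        return chunk_idx, file_idx, False  # keep appending to the current file
    file_idx += 1
    if file_idx >= chunk_size:  # current chunk is full, move to the next one
        chunk_idx += 1
        file_idx = 0
    return chunk_idx, file_idx, True  # caller should create a new parquet file
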
45 changes: 31 additions & 14 deletions src/lerobot/datasets/compute_stats.py
@@ -32,17 +32,15 @@ class RunningQuantileStats
 
     def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
         self._count = 0
-        self._mean = None
+        self._mean: np.ndarray | None = None
         self._mean_of_squares = None
-        self._min = None
-        self._max = None
-        self._histograms = None
-        self._bin_edges = None
+        self._min: np.ndarray | None = None
+        self._max: np.ndarray | None = None
+        self._histograms: list[np.ndarray] | None = None
+        self._bin_edges: list[np.ndarray] | None = None
         self._num_quantile_bins = num_quantile_bins
 
-        self._quantile_list = quantile_list
-        if self._quantile_list is None:
-            self._quantile_list = DEFAULT_QUANTILES
+        self._quantile_list: list[float] = quantile_list or DEFAULT_QUANTILES
         self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]
 
     def update(self, batch: np.ndarray) -> None:
@@ -65,6 +63,10 @@ def update(self, batch: np.ndarray) -> None:
                 for i in range(vector_length)
             ]
         else:
+            assert self._mean is not None
+            assert self._min is not None
+            assert self._max is not None
+
             if vector_length != self._mean.size:
                 raise ValueError("The length of new vectors does not match the initialized vector length.")
 
@@ -103,6 +105,10 @@ def get_statistics(self) -> dict[str, np.ndarray]:
         if self._count < 2:
             raise ValueError("Cannot compute statistics for less than 2 vectors.")
 
+        assert self._mean is not None
+        assert self._min is not None
+        assert self._max is not None
+
         variance = self._mean_of_squares - self._mean**2
 
         stddev = np.sqrt(np.maximum(0, variance))
@@ -150,12 +156,19 @@ def _adjust_histograms(self):
 
     def _update_histograms(self, batch: np.ndarray) -> None:
         """Update histograms with new vectors."""
+
+        assert self._histograms is not None
+        assert self._bin_edges is not None
+
         for i in range(batch.shape[1]):
             hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
             self._histograms[i] += hist
 
     def _compute_quantiles(self) -> list[np.ndarray]:
         """Compute quantiles based on histograms."""
+        assert self._histograms is not None
+        assert self._bin_edges is not None
+
         results = []
         for q in self._quantile_list:
             target_count = q * self._count
@@ -174,9 +187,9 @@ def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_c
         idx = np.searchsorted(cumsum, target_count)
 
         if idx == 0:
-            return edges[0]
+            return float(edges[0])
         if idx >= len(cumsum):
-            return edges[-1]
+            return float(edges[-1])
 
         # If not edge case, interpolate within the bin
         count_before = cumsum[idx - 1]
@@ -242,6 +255,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
 
         images[i] = img
 
+    assert images is not None
     return images
 
 
@@ -318,7 +332,7 @@ def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
 
 def _reshape_for_global_stats(
     value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
-) -> np.ndarray | float:
+) -> np.ndarray:
     """Reshape statistics for global reduction (axis=None)."""
     if keepdims:
         target_shape = tuple(1 for _ in original_shape)
@@ -329,7 +343,7 @@ def _reshape_for_global_stats(
 
 def _reshape_single_stat(
     value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
-) -> np.ndarray | float:
+) -> np.ndarray:
     """Apply appropriate reshaping to a single statistic array.
 
     This function transforms statistic arrays to match expected output shapes
@@ -508,11 +522,14 @@ def compute_episode_stats(
         if features[key]["dtype"] == "string":
             continue
 
+        axes_to_reduce: int | tuple[int, ...] | None
         if features[key]["dtype"] in ["image", "video"]:
+            assert isinstance(data, list)
             ep_ft_array = sample_images(data)
             axes_to_reduce = (0, 2, 3)
             keepdims = True
         else:
+            assert isinstance(data, np.ndarray)
             ep_ft_array = data
             axes_to_reduce = 0
             keepdims = data.ndim == 1
@@ -562,7 +579,7 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
                 _validate_stat_value(stat_value, stat_key, feature_key)
 
 
-def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
+def aggregate_feature_stats(stats_ft_list: list[dict[str, float]]) -> dict[str, np.ndarray]:
     """Aggregates stats for a single feature."""
     means = np.stack([s["mean"] for s in stats_ft_list])
     variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
@@ -617,7 +634,7 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
     _assert_type_and_shape(stats_list)
 
     data_keys = {key for stats in stats_list for key in stats}
-    aggregated_stats = {key: {} for key in data_keys}
+    aggregated_stats: dict[str, dict[str, np.ndarray]] = {key: {} for key in data_keys}
 
     for key in data_keys:
         stats_with_key = [stats[key] for stats in stats_list if key in stats]

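For context on the hunk above: _compute_single_quantile estimates a quantile from a per-dimension histogram by locating the bin where the cumulative count crosses q * count and interpolating linearly inside that bin; the change only wraps the two edge-case returns in float(...) so the function returns a Python float rather than a NumPy scalar. A standalone sketch of that estimate, assuming hist has N bins and edges has N + 1 bin edges (simplified, not the class's exact code):

import numpy as np


def histogram_quantile(hist: np.ndarray, edges: np.ndarray, q: float, count: int) -> float:
    """Approximate the q-quantile of `count` samples summarized by (hist, edges)."""
    target_count = q * count
    cumsum = np.cumsum(hist)
    idx = int(np.searchsorted(cumsum, target_count))
    if idx == 0:
        return float(edges[0])
    if idx >= len(cumsum):
        return float(edges[-1])
    # Interpolate within the bin whose cumulative count crosses the target.
    count_before = cumsum[idx - 1]
    frac = (target_count - count_before) / max(float(hist[idx]), 1e-12)
    return float(edges[idx] + frac * (edges[idx + 1] - edges[idx]))

For example, with hist = [1, 2, 1], edges = [0, 1, 2, 3] and count = 4, the 0.5 quantile estimate is 1.5.
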
17 changes: 10 additions & 7 deletions src/lerobot/datasets/dataset_tools.py
@@ -25,8 +25,9 @@
 
 import logging
 import shutil
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
 from pathlib import Path
+from typing import Any, cast
 
 import datasets
 import numpy as np
@@ -140,7 +141,7 @@ def delete_episodes(
 
 def split_dataset(
     dataset: LeRobotDataset,
-    splits: dict[str, float | list[int]],
+    splits: Mapping[str, float | list[int]],
     output_dir: str | Path | None = None,
 ) -> dict[str, LeRobotDataset]:
     """Split a LeRobotDataset into multiple smaller datasets.
@@ -164,12 +165,13 @@
         raise ValueError("No splits provided")
 
     if all(isinstance(v, float) for v in splits.values()):
-        splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits)
+        splits = _fractions_to_episode_indices(dataset.meta.total_episodes, cast(dict[str, float], splits))
 
-    all_episodes = set()
+    all_episodes: set[int] = set()
     for split_name, episodes in splits.items():
         if not episodes:
             raise ValueError(f"Split '{split_name}' has no episodes")
+        assert not isinstance(episodes, float)
         episode_set = set(episodes)
         if episode_set & all_episodes:
             raise ValueError("Episodes cannot appear in multiple splits")
@@ -186,6 +188,7 @@
     result_datasets = {}
 
     for split_name, episodes in splits.items():
+        assert not isinstance(episodes, float)
         logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes")
 
         split_repo_id = f"{dataset.repo_id}_{split_name}"
@@ -441,8 +444,8 @@ def remove_feature(
 
 def _fractions_to_episode_indices(
     total_episodes: int,
-    splits: dict[str, float],
-) -> dict[str, list[int]]:
+    splits: Mapping[str, float],
+) -> Mapping[str, list[int]]:
     """Convert split fractions to episode indices."""
     if sum(splits.values()) > 1.0:
         raise ValueError("Split fractions must sum to <= 1.0")
@@ -840,7 +843,7 @@ def _copy_and_reindex_episodes_metadata(
     # array([array([array([0.])]), array([array([0.])]), array([array([0.])])])
     # This happens particularly with image/video statistics. We need to detect and flatten
     # these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them.
-    episode_stats = {}
+    episode_stats: dict[str, Any] = {}
     for key in src_episode_full:
         if key.startswith("stats/"):
             stat_key = key.replace("stats/", "")

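The cast(dict[str, float], splits) call and the assert not isinstance(episodes, float) lines above exist for type checking only: with splits typed as Mapping[str, float | list[int]], the all(isinstance(v, float) ...) check does not narrow the mapping's value type, so the narrowing has to be stated explicitly for mypy. A minimal sketch of the same pattern; the block-wise assignment of indices below is hypothetical, not the module's actual splitting logic:

from collections.abc import Mapping
from typing import cast


def resolve_splits(total_episodes: int, splits: Mapping[str, float | list[int]]) -> dict[str, list[int]]:
    """Resolve fraction-based splits into explicit episode indices."""
    if all(isinstance(v, float) for v in splits.values()):
        fractions = cast(dict[str, float], splits)  # mypy cannot infer this narrowing
        resolved: dict[str, list[int]] = {}
        start = 0
        for name, frac in fractions.items():
            end = min(start + int(round(frac * total_episodes)), total_episodes)
            resolved[name] = list(range(start, end))
            start = end
        return resolved
    out: dict[str, list[int]] = {}
    for name, episodes in splits.items():
        assert not isinstance(episodes, float)  # narrows float | list[int] to list[int]
        out[name] = list(episodes)
    return out
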
12 changes: 3 additions & 9 deletions src/lerobot/datasets/factory.py
@@ -20,11 +20,7 @@
 
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.train import TrainPipelineConfig
-from lerobot.datasets.lerobot_dataset import (
-    LeRobotDataset,
-    LeRobotDatasetMetadata,
-    MultiLeRobotDataset,
-)
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata, MultiLeRobotDataset
 from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
 from lerobot.datasets.transforms import ImageTransforms
 from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
@@ -62,10 +58,7 @@ def resolve_delta_timestamps(
         if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
 
-    if len(delta_timestamps) == 0:
-        delta_timestamps = None
-
-    return delta_timestamps
+    return delta_timestamps if len(delta_timestamps) > 0 else None
 
 
 def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
@@ -88,6 +81,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
         ds_meta = LeRobotDatasetMetadata(
             cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
         )
+        assert cfg.policy is not None
         delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
         if not cfg.dataset.streaming:
             dataset = LeRobotDataset(

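The rewritten return in resolve_delta_timestamps above is behavior-preserving: an empty dict still collapses to None, just in a single expression. The values themselves are frame-offset indices converted to seconds by dividing by the dataset's fps, as in the retained [i / ds_meta.fps for i in ...] lines; a small standalone sketch with hypothetical names:

def indices_to_delta_timestamps(delta_indices: list[int] | None, fps: float) -> list[float] | None:
    """Convert frame offsets such as [-1, 0, 1] into timestamps in seconds."""
    if delta_indices is None:
        return None
    return [i / fps for i in delta_indices]


# At 30 fps, offsets [-1, 0, 1] become [-0.0333..., 0.0, 0.0333...].
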
3 changes: 2 additions & 1 deletion src/lerobot/datasets/image_writer.py
@@ -144,7 +144,7 @@ class AsyncImageWriter:
     def __init__(self, num_processes: int = 0, num_threads: int = 1):
         self.num_processes = num_processes
         self.num_threads = num_threads
-        self.queue = None
+        self.queue: queue.Queue | multiprocessing.JoinableQueue | None = None
         self.threads = []
         self.processes = []
         self._stopped = False
@@ -170,6 +170,7 @@ def __init__(self, num_processes: int = 0, num_threads: int = 1):
             self.processes.append(p)
 
     def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
+        assert self.queue is not None
         if isinstance(image, torch.Tensor):
             # Convert tensor to numpy array to minimize main process time
             image = image.cpu().numpy()

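The added assert self.queue is not None follows the narrowing pattern used throughout this change set: once an attribute is declared as queue.Queue | multiprocessing.JoinableQueue | None, mypy requires narrowing before methods such as .put(...) can be called on it. A minimal illustration with a hypothetical class (not AsyncImageWriter itself):

import queue


class Worker:
    def __init__(self) -> None:
        # Declared optional: the queue only exists after start() is called.
        self.queue: queue.Queue | None = None

    def start(self) -> None:
        self.queue = queue.Queue()

    def submit(self, item: object) -> None:
        assert self.queue is not None  # narrows Queue | None to Queue for mypy
        self.queue.put(item)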