Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 91 additions & 3 deletions src/lerobot/scripts/lerobot_edit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Edit LeRobot datasets using various transformation tools.

This script allows you to delete episodes, split datasets, merge datasets,
and remove features. When new_repo_id is specified, creates a new dataset.
remove features, and add features. When new_repo_id is specified, creates a new dataset.

Usage Examples:

Expand Down Expand Up @@ -65,6 +65,12 @@
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"

Add a feature from a NumPy .npy file (one entry per dataset frame):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type add_feature \
--operation.features '{"reward": {"file": "rewards.npy", "dtype": "float32", "shape": [1], "names": null}}'

Using JSON config file:
python -m lerobot.scripts.lerobot_edit_dataset \
--config_path path/to/edit_config.json
Expand All @@ -75,8 +81,11 @@
from dataclasses import dataclass
from pathlib import Path

import numpy as np

from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
add_features,
delete_episodes,
merge_datasets,
remove_feature,
Expand Down Expand Up @@ -111,10 +120,16 @@ class RemoveFeatureConfig:
feature_names: list[str] | None = None


@dataclass
class AddFeatureConfig:
type: str = "add_feature"
features: dict[str, dict] | None = None


@dataclass
class EditDatasetConfig:
repo_id: str
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | AddFeatureConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
Expand Down Expand Up @@ -258,6 +273,77 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()


def _load_feature_from_config(
    feature_name: str, feature_config: dict, expected_length: int
) -> tuple[np.ndarray, dict]:
    """Load and validate one feature described by an ``add_feature`` config entry.

    Args:
        feature_name: Name of the feature being added (used in error messages).
        feature_config: Spec dict with the keys ``file``, ``dtype``, ``shape``
            and ``names``.
        expected_length: Number of entries the loaded array must have along its
            first axis.

    Returns:
        A ``(feature_data, feature_info)`` pair, where ``feature_info`` holds the
        ``dtype``, ``shape`` (as a tuple when given as a list) and ``names``.

    Raises:
        ValueError: If ``file``, ``dtype`` or ``shape`` is missing, the file
            format is unsupported, or the data length does not match.
        FileNotFoundError: If the feature file does not exist.
    """
    feature_file = feature_config.get("file")
    if not feature_file:
        raise ValueError(f"Feature '{feature_name}' must specify a 'file' path to load data from")

    file_path = Path(feature_file)
    if not file_path.exists():
        raise FileNotFoundError(f"Feature file not found: {feature_file}")

    # Only the .npy format is supported for now.
    if file_path.suffix != ".npy":
        raise ValueError(f"Unsupported file format for feature '{feature_name}': {file_path.suffix}")

    dtype = feature_config.get("dtype")
    shape = feature_config.get("shape")
    # Fail early instead of forwarding None metadata downstream.
    # NOTE(review): assumes add_features needs a real dtype and shape — confirm.
    missing_keys = [key for key, value in (("dtype", dtype), ("shape", shape)) if value is None]
    if missing_keys:
        raise ValueError(f"Feature '{feature_name}' must specify {missing_keys} in its config")

    # JSON / CLI parsing yields lists; feature shapes are passed on as tuples.
    if isinstance(shape, list):
        shape = tuple(shape)

    feature_info = {
        "dtype": dtype,
        "shape": shape,
        "names": feature_config.get("names"),
    }

    feature_data = np.load(file_path)

    # A 0-d array has no length; report it clearly instead of a TypeError from len().
    if feature_data.ndim == 0:
        raise ValueError(
            f"Feature '{feature_name}' data must have one entry per frame, "
            f"but the file contains a 0-dimensional array"
        )

    if len(feature_data) != expected_length:
        raise ValueError(
            f"Feature '{feature_name}' data length ({len(feature_data)}) "
            f"does not match dataset length ({expected_length})"
        )

    return feature_data, feature_info


def handle_add_feature(cfg: EditDatasetConfig) -> None:
    """Add features loaded from ``.npy`` files to a dataset.

    Every entry of ``cfg.operation.features`` is loaded and validated against
    the dataset's total frame count before the output location is prepared, so
    an invalid config or feature file aborts the operation up front.

    Args:
        cfg: Edit configuration whose ``operation`` is an ``AddFeatureConfig``.

    Raises:
        ValueError: If the operation config is of the wrong type, no features
            are given, or a feature spec / its data is invalid.
        FileNotFoundError: If a feature file does not exist.
    """
    if not isinstance(cfg.operation, AddFeatureConfig):
        raise ValueError("Operation config must be AddFeatureConfig")

    if not cfg.operation.features:
        raise ValueError("features must be specified for add_feature operation")

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)

    # Validate and load all feature data first.
    # NOTE(review): get_output_path presumably relocates the existing dataset
    # to "<root>_old" when editing in place (the root is re-pointed there
    # below) — confirm. Validating first keeps a bad feature file from failing
    # after that step.
    expected_length = dataset.meta.total_frames
    features_dict = {
        feature_name: _load_feature_from_config(feature_name, feature_config, expected_length)
        for feature_name, feature_config in cfg.operation.features.items()
    }

    output_repo_id, output_dir = get_output_path(
        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
    )

    # In-place edit: read the source data from the "_old" location.
    if cfg.new_repo_id is None:
        dataset.root = Path(str(dataset.root) + "_old")

    logging.info(f"Adding features {list(features_dict.keys())} to {cfg.repo_id}")
    new_dataset = add_features(
        dataset,
        features=features_dict,
        output_dir=output_dir,
        repo_id=output_repo_id,
    )

    logging.info(f"Dataset saved to {output_dir}")
    logging.info(f"Updated features: {list(new_dataset.meta.features.keys())}")

    if cfg.push_to_hub:
        logging.info(f"Pushing to hub as {output_repo_id}")
        LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()


@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
operation_type = cfg.operation.type
Expand All @@ -270,10 +356,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_merge(cfg)
elif operation_type == "remove_feature":
handle_remove_feature(cfg)
elif operation_type == "add_feature":
handle_add_feature(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature"
f"Available operations: delete_episodes, split, merge, remove_feature, add_feature"
)


Expand Down