diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 83ba027bcc..e9b6fd5786 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -18,7 +18,7 @@ Edit LeRobot datasets using various transformation tools. This script allows you to delete episodes, split datasets, merge datasets, -and remove features. When new_repo_id is specified, creates a new dataset. +remove features, and add features. When new_repo_id is specified, creates a new dataset. Usage Examples: @@ -65,6 +65,12 @@ --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" +Add feature from numpy file: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type add_feature \ + --operation.features '{"reward": {"file": "rewards.npy", "dtype": "float32", "shape": [1], "names": null}}' + Using JSON config file: python -m lerobot.scripts.lerobot_edit_dataset \ --config_path path/to/edit_config.json @@ -75,8 +81,11 @@ from dataclasses import dataclass from pathlib import Path +import numpy as np + from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( + add_features, delete_episodes, merge_datasets, remove_feature, @@ -111,10 +120,16 @@ class RemoveFeatureConfig: feature_names: list[str] | None = None +@dataclass +class AddFeatureConfig: + type: str = "add_feature" + features: dict[str, dict] | None = None + + @dataclass class EditDatasetConfig: repo_id: str - operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig + operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | AddFeatureConfig root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -258,6 +273,90 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() +def handle_add_feature(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, AddFeatureConfig): + raise ValueError("Operation config must be AddFeatureConfig") + + if not cfg.operation.features: + raise ValueError("features must be specified for add_feature operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + # Process features config to load data and prepare for add_features + features_dict = {} + for feature_name, feature_config in cfg.operation.features.items(): + # Extract feature info (dtype, shape, names) + shape = feature_config.get("shape") + dtype = feature_config.get("dtype") + # Convert and validate shape before assignment + if isinstance(shape, list): + shape = tuple(shape) + elif isinstance(shape, tuple) or shape is None: + pass # shape is already valid + else: + raise ValueError( + f"Feature '{feature_name}' has invalid shape type: {type(shape).__name__}. " + "Shape must be a list, tuple, or None." + ) + # Validate required metadata fields + if dtype is None: + raise ValueError(f"Feature '{feature_name}' must specify a 'dtype' (data type)") + if shape is None: + raise ValueError(f"Feature '{feature_name}' must specify a 'shape'") + + feature_info = { + "dtype": dtype, + "shape": shape, + "names": feature_config.get("names"), + } + + # Load feature data from file + feature_file = feature_config.get("file") + if not feature_file: + raise ValueError(f"Feature '{feature_name}' must specify a 'file' path to load data from") + + file_path = Path(feature_file) + if not file_path.exists(): + raise FileNotFoundError(f"Feature file not found: {feature_file}") + + # Load numpy array + if file_path.suffix == ".npy": + feature_data = np.load(file_path) + else: + raise ValueError(f"Unsupported file format for feature '{feature_name}': {file_path.suffix}") + + # Validate data length matches dataset + expected_length = dataset.meta.total_frames + if len(feature_data) != expected_length: + raise ValueError( + f"Feature '{feature_name}' data length ({len(feature_data)}) " + f"does not match dataset length ({expected_length})" + ) + + features_dict[feature_name] = (feature_data, feature_info) + + logging.info(f"Adding features {list(features_dict.keys())} to {cfg.repo_id}") + new_dataset = add_features( + dataset, + features=features_dict, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Updated features: {list(new_dataset.meta.features.keys())}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + @parser.wrap() def edit_dataset(cfg: EditDatasetConfig) -> None: operation_type = cfg.operation.type @@ -270,10 +369,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) + elif operation_type == "add_feature": + handle_add_feature(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature" + f"Available operations: delete_episodes, split, merge, remove_feature, add_feature" )