diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3e36a0d981..499150af00 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -9,6 +9,8 @@ title: Imitation Learning for Robots - local: cameras title: Cameras + - local: bring_your_own_policies + title: Bring Your Own Policies - local: integrate_hardware title: Bring Your Own Hardware - local: hilserl diff --git a/docs/source/bring_your_own_policies.mdx b/docs/source/bring_your_own_policies.mdx new file mode 100644 index 0000000000..71636bec20 --- /dev/null +++ b/docs/source/bring_your_own_policies.mdx @@ -0,0 +1,175 @@ +# Bring Your Own Policies + +This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms. + +## Step 1: Create a Policy Package + +Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions. + +### Package Structure + +Create a package with the prefix `lerobot_policy_` (IMPORTANT!) 
followed by your policy name: + +```bash +lerobot_policy_my_custom_policy/ +├── pyproject.toml +└── src/ + └── lerobot_policy_my_custom_policy/ + ├── __init__.py + ├── configuration_my_custom_policy.py + ├── modeling_my_custom_policy.py + └── processor_my_custom_policy.py +``` + +### Package Configuration + +Set up your `pyproject.toml`: + +```toml +[project] +name = "lerobot_policy_my_custom_policy" +version = "0.1.0" +dependencies = [ + # your policy-specific dependencies +] +requires-python = ">= 3.11" + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] +``` + +## Step 2: Define the Policy Configuration + +Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type: + +```python +# configuration_my_custom_policy.py +from dataclasses import dataclass, field +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + +@PreTrainedConfig.register_subclass("my_custom_policy") +@dataclass +class MyCustomPolicyConfig(PreTrainedConfig): + """Configuration class for MyCustomPolicy. + + Args: + n_obs_steps: Number of observation steps to use as input + horizon: Action prediction horizon + n_action_steps: Number of action steps to execute + hidden_dim: Hidden dimension for the policy network + # Add your policy-specific parameters here + """ + # ...PreTrainedConfig fields... 
+ pass + + def __post_init__(self): + super().__post_init__() + # Add any validation logic here + + def validate_features(self) -> None: + """Validate input/output feature compatibility.""" + # Implement validation logic for your policy's requirements + pass +``` + +## Step 3: Implement the Policy Class + +Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class. LeRobot locates it by naming convention: `MyCustomPolicyConfig` in `configuration_my_custom_policy.py` resolves to `MyCustomPolicy` in `modeling_my_custom_policy.py`: + +```python +# modeling_my_custom_policy.py +import torch +import torch.nn as nn +from typing import Dict, Any + +from lerobot.policies.pretrained import PreTrainedPolicy +from .configuration_my_custom_policy import MyCustomPolicyConfig + +class MyCustomPolicy(PreTrainedPolicy): + config_class = MyCustomPolicyConfig + name = "my_custom_policy" + + def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None): + super().__init__(config, dataset_stats) + ... +``` + +## Step 4: Add Data Processors + +Create a processor factory named `make_<policy_type>_pre_post_processors` in `processor_my_custom_policy.py`. LeRobot calls it with the config and a `dataset_stats` keyword argument: + +```python +# processor_my_custom_policy.py +from typing import Any +from lerobot.processor import PolicyAction, PolicyProcessorPipeline + + +def make_my_custom_policy_pre_post_processors( + config, + dataset_stats=None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Create preprocessing and postprocessing functions for your policy.""" + pass # Define your preprocessing and postprocessing logic here +``` + +## Step 5: Package Initialization + +Expose your classes in the package's `__init__.py`: + +```python +# __init__.py +"""Custom policy package for LeRobot.""" + +try: + import lerobot # noqa: F401 +except ImportError: + raise ImportError( + "lerobot is not installed. Please install lerobot to use this policy package." 
+ ) + +from .configuration_my_custom_policy import MyCustomPolicyConfig +from .modeling_my_custom_policy import MyCustomPolicy +from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors + +__all__ = [ + "MyCustomPolicyConfig", + "MyCustomPolicy", + "make_my_custom_policy_pre_post_processors", +] +``` + +## Step 6: Installation and Usage + +### Install Your Policy Package + +```bash +cd lerobot_policy_my_custom_policy +pip install -e . + +# Or install from PyPI if published +pip install lerobot_policy_my_custom_policy +``` + +### Use Your Policy + +Once installed, your policy automatically integrates with LeRobot's training and evaluation tools: + +```bash +lerobot-train \ + --policy.type my_custom_policy \ + --env.type pusht \ + --steps 200000 +``` + +## Examples and Community Contributions + +Check out these example policy implementations: + +- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow) + +Share your policy implementations with the community! 
🤗 diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index bdad5cbb3f..1798dab9f5 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -16,6 +16,7 @@ from __future__ import annotations +import importlib import logging from typing import Any, TypedDict @@ -107,7 +108,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: return GrootPolicy else: - raise NotImplementedError(f"Policy with name {name} is not implemented.") + try: + return _get_policy_cls_from_policy_name(name=name) + except Exception as e: + raise ValueError(f"Policy type '{name}' is not available.") from e def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: @@ -150,7 +154,11 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: elif policy_type == "groot": return GrootConfig(**kwargs) else: - raise ValueError(f"Policy type '{policy_type}' is not available.") + try: + config_cls = PreTrainedConfig.get_choice_class(policy_type) + return config_cls(**kwargs) + except Exception as e: + raise ValueError(f"Policy type '{policy_type}' is not available.") from e class ProcessorConfigKwargs(TypedDict, total=False): @@ -330,7 +338,13 @@ def make_pre_post_processors( ) else: - raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") + try: + processors = _make_processors_from_policy_config( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + except Exception as e: + raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e return processors @@ -437,3 +451,65 @@ def make_policy( f'"observation.images.top": "observation.images.camera2"}}\'' ) return policy + + +def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedPolicy]: + """Get policy class from its registered name using dynamic imports. + + This is used as a helper function to import policies from 3rd party lerobot plugins. 
+ + Args: + name: The name of the policy. + Returns: + The policy class corresponding to the given name. + """ + if name not in PreTrainedConfig.get_known_choices(): + raise ValueError( + f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}" + ) + + config_cls = PreTrainedConfig.get_choice_class(name) + config_cls_name = config_cls.__name__ + + model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion + if model_name == config_cls_name: + raise ValueError( + f"The config class name '{config_cls_name}' does not follow the expected naming convention." + f" Make sure it ends with 'Config'!" + ) + cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy + module_path = config_cls.__module__.replace( + "configuration_", "modeling_" + ) # e.g., configuration_diffusion -> modeling_diffusion + + module = importlib.import_module(module_path) + policy_cls = getattr(module, cls_name) + return policy_cls + + +def _make_processors_from_policy_config( + config: PreTrainedConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[Any, Any]: + """Create pre- and post-processors from a policy configuration using dynamic imports. + + This is used as a helper function to import processor factories from 3rd party lerobot plugins. + + Args: + config: The policy configuration object. + dataset_stats: Dataset statistics for normalization. + Returns: + A tuple containing the input (pre-processor) and output (post-processor) pipelines. 
+ """ + + policy_type = config.type + function_name = f"make_{policy_type}_pre_post_processors" + module_path = config.__class__.__module__.replace( + "configuration_", "processor_" + ) # e.g., configuration_diffusion -> processor_diffusion + logging.debug( + f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'" + ) + module = importlib.import_module(module_path) + function = getattr(module, function_name) + return function(config, dataset_stats=dataset_stats) diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 0f247caefe..8247ec0535 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -52,7 +52,7 @@ so100_leader, so101_leader, ) -from lerobot.utils.import_utils import register_third_party_devices +from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.utils import init_logging @@ -84,7 +84,7 @@ def calibrate(cfg: CalibrateConfig): def main(): - register_third_party_devices() + register_third_party_plugins() calibrate() diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 0d66fa1aa0..79823f62be 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -82,6 +82,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD +from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -760,6 +761,7 @@ def _agg_from_list(xs): def main(): init_logging() + register_third_party_plugins() eval_main() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 6df92d893b..5220d69875 100644 --- 
a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -118,7 +118,7 @@ sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, ) -from lerobot.utils.import_utils import register_third_party_devices +from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( get_safe_torch_device, @@ -512,7 +512,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: def main(): - register_third_party_devices() + register_third_party_plugins() record() diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index ffd7b2b22a..7f00aacb9c 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -61,7 +61,7 @@ so101_follower, ) from lerobot.utils.constants import ACTION -from lerobot.utils.import_utils import register_third_party_devices +from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( init_logging, @@ -127,7 +127,7 @@ def replay(cfg: ReplayConfig): def main(): - register_third_party_devices() + register_third_party_plugins() replay() diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 0a418f3bca..8fea937f3a 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -88,7 +88,7 @@ so100_leader, so101_leader, ) -from lerobot.utils.import_utils import register_third_party_devices +from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import init_logging, move_cursor_up from lerobot.utils.visualization_utils import init_rerun, log_rerun_data @@ -216,7 +216,7 @@ def teleoperate(cfg: TeleoperateConfig): def main(): - register_third_party_devices() + register_third_party_plugins() teleoperate() 
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 0cc6e037fd..c999d58e22 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -36,6 +36,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.rl.wandb_utils import WandBLogger from lerobot.scripts.lerobot_eval import eval_policy_all +from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( @@ -441,6 +442,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): def main(): + register_third_party_plugins() train() diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index b9a9e68252..0dd9db5160 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -130,14 +130,14 @@ class name and tries a few candidate modules where the device implementation is ) -def register_third_party_devices() -> None: +def register_third_party_plugins() -> None: """ Discover and import third-party lerobot_* plugins so they can register themselves. Scans top-level modules on sys.path for packages starting with - 'lerobot_robot_', 'lerobot_camera_' or 'lerobot_teleoperator_' and imports them. + 'lerobot_robot_', 'lerobot_camera_', 'lerobot_teleoperator_' or 'lerobot_policy_' and imports them. """ - prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_") + prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_") imported: list[str] = [] failed: list[str] = []