Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
title: Imitation Learning for Robots
- local: cameras
title: Cameras
- local: bring_your_own_policies
title: Bring Your Own Policies
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
Expand Down
175 changes: 175 additions & 0 deletions docs/source/bring_your_own_policies.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Bring Your Own Policies

This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.

## Step 1: Create a Policy Package

Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.

### Package Structure

Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:

```bash
lerobot_policy_my_custom_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_custom_policy/
├── __init__.py
├── configuration_my_custom_policy.py
├── modeling_my_custom_policy.py
└── processor_my_custom_policy.py
```

### Package Configuration

Set up your `pyproject.toml`:

```toml
[project]
name = "lerobot_policy_my_custom_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.11"

[build-system]
build-backend = "hatchling.build"
requires = ["hatchling"]
```

## Step 2: Define the Policy Configuration

Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:

```python
# configuration_my_custom_policy.py
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode

@PreTrainedConfig.register_subclass("my_custom_policy")
@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
"""Configuration class for MyCustomPolicy.

Args:
n_obs_steps: Number of observation steps to use as input
horizon: Action prediction horizon
n_action_steps: Number of action steps to execute
hidden_dim: Hidden dimension for the policy network
# Add your policy-specific parameters here
"""
# ...PreTrainedConfig fields...
pass

def __post_init__(self):
super().__post_init__()
# Add any validation logic here

def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
# Implement validation logic for your policy's requirements
pass
```

## Step 3: Implement the Policy Class

Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:

```python
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Dict, Any

from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig

class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
name = "my_custom_policy"

def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
super().__init__(config, dataset_stats)
...
```

## Step 4: Add Data Processors

Create processor functions:

```python
# processor_my_custom_policy.py
from typing import Dict, Any
import torch


def make_my_custom_policy_pre_post_processors(
config,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create preprocessing and postprocessing functions for your policy."""
pass # Define your preprocessing and postprocessing logic here

```

## Step 5: Package Initialization

Expose your classes in the package's `__init__.py`:

```python
# __init__.py
"""Custom policy package for LeRobot."""

try:
import lerobot # noqa: F401
except ImportError:
raise ImportError(
"lerobot is not installed. Please install lerobot to use this policy package."
)

from .configuration_my_custom_policy import MyCustomPolicyConfig
from .modeling_my_custom_policy import MyCustomPolicy
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors

__all__ = [
"MyCustomPolicyConfig",
"MyCustomPolicy",
"make_my_custom_policy_pre_post_processors",
]
```

## Step 6: Installation and Usage

### Install Your Policy Package

```bash
cd lerobot_policy_my_custom_policy
pip install -e .

# Or install from PyPI if published
pip install lerobot_policy_my_custom_policy
```

### Use Your Policy

Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:

```bash
lerobot-train \
--policy.type my_custom_policy \
--env.type pusht \
--steps 200000
```

## Examples and Community Contributions

Check out these example policy implementations:

- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)

Share your policy implementations with the community! 🤗
82 changes: 79 additions & 3 deletions src/lerobot/policies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import importlib
import logging
from typing import Any, TypedDict

Expand Down Expand Up @@ -107,7 +108,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:

return GrootPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
try:
return _get_policy_cls_from_policy_name(name=name)
except Exception as e:
raise ValueError(f"Policy type '{name}' is not available.") from e


def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Expand Down Expand Up @@ -150,7 +154,11 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
elif policy_type == "groot":
return GrootConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Policy type '{policy_type}' is not available.") from e


class ProcessorConfigKwargs(TypedDict, total=False):
Expand Down Expand Up @@ -330,7 +338,13 @@ def make_pre_post_processors(
)

else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
try:
processors = _make_processors_from_policy_config(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e

return processors

Expand Down Expand Up @@ -437,3 +451,65 @@ def make_policy(
f'"observation.images.top": "observation.images.camera2"}}\''
)
return policy


def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
"""Get policy class from its registered name using dynamic imports.

This is used as a helper function to import policies from 3rd party lerobot plugins.

Args:
name: The name of the policy.
Returns:
The policy class corresponding to the given name.
"""
if name not in PreTrainedConfig.get_known_choices():
raise ValueError(
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
)

config_cls = PreTrainedConfig.get_choice_class(name)
config_cls_name = config_cls.__name__

model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
module_path = config_cls.__module__.replace(
"configuration_", "modeling_"
) # e.g., configuration_diffusion -> modeling_diffusion

module = importlib.import_module(module_path)
policy_cls = getattr(module, cls_name)
return policy_cls


def _make_processors_from_policy_config(
config: PreTrainedConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a policy configuration using dynamic imports.

This is used as a helper function to import processor factories from 3rd party lerobot plugins.

Args:
config: The policy configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""

policy_type = config.type
function_name = f"make_{policy_type}_pre_post_processors"
module_path = config.__class__.__module__.replace(
"configuration_", "processor_"
) # e.g., configuration_diffusion -> processor_diffusion
logging.debug(
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)
4 changes: 2 additions & 2 deletions src/lerobot/scripts/lerobot_calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.utils import init_logging


Expand Down Expand Up @@ -84,7 +84,7 @@ def calibrate(cfg: CalibrateConfig):


def main():
register_third_party_devices()
register_third_party_plugins()
calibrate()


Expand Down
2 changes: 2 additions & 0 deletions src/lerobot/scripts/lerobot_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
Expand Down Expand Up @@ -760,6 +761,7 @@ def _agg_from_list(xs):

def main():
init_logging()
register_third_party_plugins()
eval_main()


Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/scripts/lerobot_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
get_safe_torch_device,
Expand Down Expand Up @@ -512,7 +512,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:


def main():
register_third_party_devices()
register_third_party_plugins()
record()


Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/scripts/lerobot_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
Expand Down Expand Up @@ -127,7 +127,7 @@ def replay(cfg: ReplayConfig):


def main():
register_third_party_devices()
register_third_party_plugins()
replay()


Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/scripts/lerobot_teleoperate.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
Expand Down Expand Up @@ -216,7 +216,7 @@ def teleoperate(cfg: TeleoperateConfig):


def main():
register_third_party_devices()
register_third_party_plugins()
teleoperate()


Expand Down
Loading