41 changes: 41 additions & 0 deletions README.md
@@ -290,6 +290,47 @@ lerobot-train --config_path=lerobot/diffusion_pusht

reproduces SOTA results for Diffusion Policy on the PushT task.

### Remote policy evaluation (experimental)

You can delegate action selection to a remote machine by pointing the `remote`
policy to the async inference gRPC policy server. Start the server either
directly or through the compatibility wrapper:

```bash
# Option 1: run the async inference server module
python -m lerobot.async_inference.policy_server --host 0.0.0.0 --port 8080

# Option 2: backward-compatible entry point
python examples/remote/remote_policy_server.py --host 0.0.0.0 --port 8080
```

Then launch evaluation with the remote policy pointing to that server:

```bash
lerobot-eval \
--env.type=libero \
--env.task=libero_spatial \
--env.max_parallel_tasks=1 \
--eval.batch_size=1 \
--eval.n_episodes=3 \
--policy.type=remote \
--policy.server_address=localhost:8080 \
--policy.request_timeout=30 \
--policy.retries=3 \
--policy.n_action_steps=10 \
--policy.remote_policy_type=pi05 \
--policy.remote_pretrained_name_or_path=lerobot/pi05_libero_finetuned \
--policy.remote_policy_device=cuda \
--rename_map='"--rename_map={"observation.images.empty_camera_0":"observation.images.image"}' \
--output_dir=./eval_logs_libero_spatial
```

The `--rename_map` option remaps environment observation keys to the feature names the remote
policy expects. The optional `additional_args` payload is forwarded to the async inference server
alongside the observation batch and can be adjusted to match your remote model's expectations.

If you omit `--policy.remote_policy_type`, the remote checkpoint’s config is loaded to infer it automatically.
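
For reference, here is a minimal Python sketch of that lookup and of setting `additional_args`
programmatically (the `inference_config` key below is purely illustrative):

```python
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.remote import RemoteConfig

# validate_features() performs this lookup when remote_policy_type is left empty.
checkpoint_cfg = PreTrainedConfig.from_pretrained("lerobot/pi05_libero_finetuned")
print(checkpoint_cfg.type)  # e.g. "pi05"

# additional_args entries are injected into the observation sent to the server;
# the "inference_config" key is only an example payload.
cfg = RemoteConfig(
    remote_pretrained_name_or_path="lerobot/pi05_libero_finetuned",
    remote_policy_type=checkpoint_cfg.type,
    n_action_steps=10,
    additional_args={"inference_config": {"num_inference_steps": 10}},
)
```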

## Contribute

If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
12 changes: 12 additions & 0 deletions examples/remote/remote_policy_server.py
@@ -0,0 +1,12 @@
"""
Backward-compatible entry point for the async inference policy server.

Rather than running a custom FastAPI stub, the remote policy now relies on the
shared async inference gRPC implementation. You can start the server from this
module or via ``python -m lerobot.async_inference.policy_server``.
"""

from lerobot.async_inference.policy_server import serve

if __name__ == "__main__":
serve()
2 changes: 1 addition & 1 deletion src/lerobot/async_inference/constants.py
@@ -23,7 +23,7 @@
DEFAULT_OBS_QUEUE_TIMEOUT = 2

# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "remote"]

# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
2 changes: 2 additions & 0 deletions src/lerobot/policies/__init__.py
@@ -20,6 +20,7 @@
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .remote.configuration_remote import RemoteConfig as RemoteConfig

__all__ = [
"ACTConfig",
@@ -29,4 +30,5 @@
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
"RemoteConfig",
]
21 changes: 21 additions & 0 deletions src/lerobot/policies/factory.py
@@ -33,6 +33,7 @@
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.remote.configuration_remote import RemoteConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
@@ -101,6 +102,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

return SmolVLAPolicy
elif name == "remote":
from lerobot.policies.remote.modeling_remote import RemotePolicy

return RemotePolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")

@@ -142,6 +147,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "remote":
return RemoteConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

@@ -293,6 +300,17 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)

elif isinstance(policy_cfg, RemoteConfig):
from lerobot.policies.remote.processor_remote import make_remote_pre_post_processors

overrides = kwargs.get("preprocessor_overrides") or {}

processors = make_remote_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
rename_map=overrides.get("rename_observations_processor", {}).get("rename_map", {}),
)

else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")

@@ -350,6 +368,9 @@ def make_policy(
policy_cls = get_policy_class(cfg.type)

kwargs = {}
if cfg.type == "remote":
cfg.rename_map = rename_map or {}

if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
else:
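For reference, a minimal sketch of the new dispatch path added above, using only the helpers shown in this diff (the server address and repo ID are placeholders):

```python
from lerobot.policies.factory import get_policy_class, make_policy_config

# make_policy_config("remote", ...) now returns a RemoteConfig (see the hunk above).
cfg = make_policy_config(
    "remote",
    server_address="localhost:8080",
    remote_pretrained_name_or_path="lerobot/pi05_libero_finetuned",
    n_action_steps=10,
)

# get_policy_class("remote") lazily imports and returns RemotePolicy.
policy_cls = get_policy_class("remote")
```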
5 changes: 5 additions & 0 deletions src/lerobot/policies/remote/__init__.py
@@ -0,0 +1,5 @@
from .configuration_remote import RemoteConfig
from .modeling_remote import RemotePolicy
from .processor_remote import make_remote_pre_post_processors

__all__ = ["RemoteConfig", "RemotePolicy", "make_remote_pre_post_processors"]
127 changes: 127 additions & 0 deletions src/lerobot/policies/remote/configuration_remote.py
@@ -0,0 +1,127 @@
from dataclasses import dataclass, field
from typing import Any

from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim.optimizers import AdamWConfig


@PreTrainedConfig.register_subclass("remote")
@dataclass
class RemoteConfig(PreTrainedConfig):
# Identity and device placement
type: str = field(default="remote", metadata={"help": "Policy type name"})
device: str = field(default="cpu", metadata={"help": "Device used for returned tensors"})

# Action execution
# How many environment steps to execute per policy call. Used by the runtime action queue.
n_action_steps: int = field(default=1, metadata={"help": "Number of env steps to execute per call"})

# Remote-specific (gRPC policy server)
server_address: str = field(
default="localhost:8080", metadata={"help": "Async inference policy server address (host:port)"}
)
request_timeout: float = field(default=30.0, metadata={"help": "gRPC request timeout in seconds"})
retries: int = field(default=3, metadata={"help": "Number of retry attempts for failed RPC calls"})

remote_policy_type: str = field(
default="",
metadata={"help": "Policy type for the async inference server to load (e.g. act, diffusion)"},
)
remote_pretrained_name_or_path: str = field(
default="",
metadata={
"help": (
"Pretrained model repo ID or path for the async inference server. "
"Should match a directory containing policy weights or a Hugging Face repo ID."
)
},
)
remote_policy_device: str = field(
default="cpu", metadata={"help": "Device on which the async inference server loads the policy"}
)

actions_per_chunk: int | None = field(
default=None,
metadata={
"help": (
"Number of actions returned per chunk by the remote server. "
"Defaults to `n_action_steps` when not provided."
)
},
)
rename_map: dict[str, str] = field(
default_factory=dict,
metadata={
"help": (
"Observation rename map forwarded to the async inference server so it can match "
"environment keys to the policy's expected features."
)
},
)

# Additional arguments to inject directly into the observation dict (e.g. {"inference_config": {...}})
additional_args: dict[str, Any] = field(
default_factory=dict,
metadata={"help": "Extra observation keys to inject directly into observation"},
)

# --- Abstract API implementations required by PreTrainedConfig ---
def get_optimizer_preset(self) -> AdamWConfig:
"""Remote policy is inference-only; return an inert preset for API compatibility."""
return AdamWConfig(lr=1e-5, weight_decay=0.0, grad_clip_norm=1.0)

def get_scheduler_preset(self):
# No scheduler needed for inference-only policy
return None

def validate_features(self) -> None:
if not self.remote_pretrained_name_or_path:
raise ValueError(
"RemoteConfig expects `remote_pretrained_name_or_path` to be provided so the server can load the policy."
)

remote_cfg: PreTrainedConfig | None = None
if not self.remote_policy_type or not self.input_features or not self.output_features:
remote_cfg = PreTrainedConfig.from_pretrained(self.remote_pretrained_name_or_path)

if not self.remote_policy_type:
self.remote_policy_type = remote_cfg.type if remote_cfg is not None else ""

if remote_cfg is not None and remote_cfg.type != self.remote_policy_type:
raise ValueError(
f"Loaded remote policy config type '{remote_cfg.type}' does not match "
f"requested remote_policy_type '{self.remote_policy_type}'."
)

if not self.input_features and remote_cfg is not None:
self.input_features = remote_cfg.input_features

if not self.output_features and remote_cfg is not None:
self.output_features = remote_cfg.output_features

if not self.input_features:
raise ValueError("RemoteConfig requires `input_features` to be defined.")
if not self.remote_policy_type:
raise ValueError("RemoteConfig expects `remote_policy_type` to be set for async inference.")
if self.effective_actions_per_chunk <= 0:
raise ValueError("RemoteConfig requires `actions_per_chunk` or `n_action_steps` to be positive.")
if self.retries < 1:
raise ValueError("RemoteConfig expects `retries` to be at least 1.")

@property
def effective_actions_per_chunk(self) -> int:
return self.actions_per_chunk or self.n_action_steps

@property
def observation_delta_indices(self):
# No temporal deltas required for observations by default
return None

@property
def action_delta_indices(self):
# Minimal behavior: align deltas to n_action_steps
return list(range(self.n_action_steps))

@property
def reward_delta_indices(self):
return None
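For reference, a short sketch of the `actions_per_chunk` fallback defined above (assuming the inherited `PreTrainedConfig` fields all have defaults; the repo ID is a placeholder):

```python
from lerobot.policies.remote import RemoteConfig

cfg = RemoteConfig(
    remote_pretrained_name_or_path="lerobot/pi05_libero_finetuned",
    n_action_steps=10,
)
assert cfg.effective_actions_per_chunk == 10        # falls back to n_action_steps
assert cfg.action_delta_indices == list(range(10))  # deltas track n_action_steps

cfg.actions_per_chunk = 50
assert cfg.effective_actions_per_chunk == 50        # explicit chunk size takes precedence
```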