diff --git a/README.md b/README.md index 56d82c0c70..5d8ad6719d 100644 --- a/README.md +++ b/README.md @@ -290,6 +290,47 @@ lerobot-train --config_path=lerobot/diffusion_pusht reproduces SOTA results for Diffusion Policy on the PushT task. +### Remote policy evaluation (experimental) + +You can delegate action selection to a remote machine by pointing the `remote` +policy to the async inference gRPC policy server. Start the server either +directly or through the compatibility wrapper: + +```bash +# Option 1: run the async inference server module +python -m lerobot.async_inference.policy_server --host 0.0.0.0 --port 8080 + +# Option 2: backward-compatible entry point +python examples/remote/remote_policy_server.py --host 0.0.0.0 --port 8080 +``` + +Then launch evaluation with the remote policy pointing to that server: + +```bash +lerobot-eval \ + --env.type=libero \ + --env.task=libero_spatial \ + --env.max_parallel_tasks=1 \ + --eval.batch_size=1 \ + --eval.n_episodes=3 \ + --policy.type=remote \ + --policy.server_address=localhost:8080 \ + --policy.request_timeout=30 \ + --policy.retries=3 \ + --policy.n_action_steps=10 \ + --policy.remote_policy_type=pi05 \ + --policy.remote_pretrained_name_or_path=lerobot/pi05_libero_finetuned \ + --policy.remote_policy_device=cuda \ + --rename_map='{"observation.images.empty_camera_0":"observation.images.image"}' \ + --output_dir=./eval_logs_libero_spatial +``` + +The optional `additional_args` payload is forwarded to the async inference server +alongside the observation batch and can be adjusted to match your remote model’s +expectations. + +If you omit `--policy.remote_policy_type`, the remote checkpoint’s config is loaded to infer it automatically. + ## Contribute If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md). diff --git a/examples/remote/remote_policy_server.py b/examples/remote/remote_policy_server.py new file mode 100644 index 0000000000..b502902c14 --- /dev/null +++ b/examples/remote/remote_policy_server.py @@ -0,0 +1,12 @@ +""" +Backward-compatible entry point for the async inference policy server. + +Rather than running a custom FastAPI stub, the remote policy now relies on the +shared async inference gRPC implementation. You can start the server from this +module or via ``python -m lerobot.async_inference.policy_server``.
+""" + +from lerobot.async_inference.policy_server import serve + +if __name__ == "__main__": + serve() diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py index 1b1dac0f57..d55aae9c56 100644 --- a/src/lerobot/async_inference/constants.py +++ b/src/lerobot/async_inference/constants.py @@ -23,7 +23,7 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2 # All action chunking policies -SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] +SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "remote"] # TODO: Add all other robots SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"] diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 49f1e0f955..59a0cf01e5 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -20,6 +20,7 @@ from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig +from .remote.configuration_remote import RemoteConfig as RemoteConfig __all__ = [ "ACTConfig", @@ -29,4 +30,5 @@ "SmolVLAConfig", "TDMPCConfig", "VQBeTConfig", + "RemoteConfig", ] diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 6e524f2ab0..e4173c9b20 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -33,6 +33,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.remote.configuration_remote import RemoteConfig from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig @@ -101,6 +102,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy + elif name == "remote": + from lerobot.policies.remote.modeling_remote import RemotePolicy + + return RemotePolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -142,6 +147,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SmolVLAConfig(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) + elif policy_type == "remote": + return RemoteConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") @@ -293,6 +300,17 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, RemoteConfig): + from lerobot.policies.remote.processor_remote import make_remote_pre_post_processors + + overrides = kwargs.get("preprocessor_overrides") or {} + + processors = make_remote_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + rename_map=overrides.get("rename_observations_processor", {}).get("rename_map", {}), + ) + else: raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") @@ -350,6 +368,9 @@ def make_policy( policy_cls = get_policy_class(cfg.type) kwargs = {} + if cfg.type == "remote": + cfg.rename_map = rename_map or {} + if ds_meta is not None: features = dataset_to_policy_features(ds_meta.features) else: diff --git 
a/src/lerobot/policies/remote/__init__.py b/src/lerobot/policies/remote/__init__.py new file mode 100644 index 0000000000..21d9580c69 --- /dev/null +++ b/src/lerobot/policies/remote/__init__.py @@ -0,0 +1,5 @@ +from .configuration_remote import RemoteConfig +from .modeling_remote import RemotePolicy +from .processor_remote import make_remote_pre_post_processors + +__all__ = ["RemoteConfig", "RemotePolicy", "make_remote_pre_post_processors"] diff --git a/src/lerobot/policies/remote/configuration_remote.py b/src/lerobot/policies/remote/configuration_remote.py new file mode 100644 index 0000000000..9cd13ee664 --- /dev/null +++ b/src/lerobot/policies/remote/configuration_remote.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass, field +from typing import Any + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.optim.optimizers import AdamWConfig + + +@PreTrainedConfig.register_subclass("remote") +@dataclass +class RemoteConfig(PreTrainedConfig): + # Identity and device placement + type: str = field(default="remote", metadata={"help": "Policy type name"}) + device: str = field(default="cpu", metadata={"help": "Device used for returned tensors"}) + + # Action execution + # How many environment steps to execute per policy call. Used by the runtime action queue. + n_action_steps: int = field(default=1, metadata={"help": "Number of env steps to execute per call"}) + + # Remote-specific (gRPC policy server) + server_address: str = field( + default="localhost:8080", metadata={"help": "Async inference policy server address (host:port)"} + ) + request_timeout: float = field(default=30.0, metadata={"help": "gRPC request timeout in seconds"}) + retries: int = field(default=3, metadata={"help": "Number of retry attempts for failed RPC calls"}) + + remote_policy_type: str = field( + default="", + metadata={"help": "Policy type for the async inference server to load (e.g. act, diffusion)"}, + ) + remote_pretrained_name_or_path: str = field( + default="", + metadata={ + "help": ( + "Pretrained model repo ID or path for the async inference server. " + "Should match a directory containing policy weights or a Hugging Face repo ID." + ) + }, + ) + remote_policy_device: str = field( + default="cpu", metadata={"help": "Device on which the async inference server loads the policy"} + ) + + actions_per_chunk: int | None = field( + default=None, + metadata={ + "help": ( + "Number of actions returned per chunk by the remote server. " + "Defaults to `n_action_steps` when not provided." + ) + }, + ) + rename_map: dict[str, str] = field( + default_factory=dict, + metadata={ + "help": ( + "Observation rename map forwarded to the async inference server so it can match " + "environment keys to the policy's expected features." + ) + }, + ) + + # Additional arguments to inject directly into the observation dict (e.g. 
{"inference_config": {...}}) + additional_args: dict[str, Any] = field( + default_factory=dict, + metadata={"help": "Extra observation keys to inject directly into observation"}, + ) + + # --- Abstract API implementations required by PreTrainedConfig --- + def get_optimizer_preset(self) -> AdamWConfig: + """Remote policy is inference-only; return an inert preset for API compatibility.""" + return AdamWConfig(lr=1e-5, weight_decay=0.0, grad_clip_norm=1.0) + + def get_scheduler_preset(self): + # No scheduler needed for inference-only policy + return None + + def validate_features(self) -> None: + if not self.remote_pretrained_name_or_path: + raise ValueError( + "RemoteConfig expects `remote_pretrained_name_or_path` to be provided so the server can load the policy." + ) + + remote_cfg: PreTrainedConfig | None = None + if not self.remote_policy_type or not self.input_features or not self.output_features: + remote_cfg = PreTrainedConfig.from_pretrained(self.remote_pretrained_name_or_path) + + if not self.remote_policy_type: + self.remote_policy_type = remote_cfg.type if remote_cfg is not None else "" + + if remote_cfg is not None and remote_cfg.type != self.remote_policy_type: + raise ValueError( + f"Loaded remote policy config type '{remote_cfg.type}' does not match " + f"requested remote_policy_type '{self.remote_policy_type}'." + ) + + if not self.input_features and remote_cfg is not None: + self.input_features = remote_cfg.input_features + + if not self.output_features and remote_cfg is not None: + self.output_features = remote_cfg.output_features + + if not self.input_features: + raise ValueError("RemoteConfig requires `input_features` to be defined.") + if not self.remote_policy_type: + raise ValueError("RemoteConfig expects `remote_policy_type` to be set for async inference.") + if self.effective_actions_per_chunk <= 0: + raise ValueError("RemoteConfig requires `actions_per_chunk` or `n_action_steps` to be positive.") + if self.retries < 1: + raise ValueError("RemoteConfig expects `retries` to be at least 1.") + + @property + def effective_actions_per_chunk(self) -> int: + return self.actions_per_chunk or self.n_action_steps + + @property + def observation_delta_indices(self): + # No temporal deltas required for observations by default + return None + + @property + def action_delta_indices(self): + # Minimal behavior: align deltas to n_action_steps + return list(range(self.n_action_steps)) + + @property + def reward_delta_indices(self): + return None diff --git a/src/lerobot/policies/remote/modeling_remote.py b/src/lerobot/policies/remote/modeling_remote.py new file mode 100644 index 0000000000..616284fdae --- /dev/null +++ b/src/lerobot/policies/remote/modeling_remote.py @@ -0,0 +1,287 @@ +import logging +import pickle # nosec B403 - trusted channel between client/server +import threading +import time +from collections import deque +from typing import Any + +import grpc +import torch +from torch import Tensor + +from lerobot.async_inference.helpers import RemotePolicyConfig, TimedAction, TimedObservation +from lerobot.configs.types import FeatureType +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.transport import services_pb2, services_pb2_grpc +from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks +from lerobot.utils.constants import OBS_STR + +from .configuration_remote import RemoteConfig + +logger = logging.getLogger(__name__) + + +class RemotePolicy(PreTrainedPolicy): + """ + A policy that proxies inference to the async inference 
gRPC policy server. + """ + + config_class = RemoteConfig + name = "remote" + + def __init__(self, config: RemoteConfig): + super().__init__(config) + config.validate_features() + self._vector_name_map: dict[str, list[str]] = {} + self._image_key_map: dict[str, str] = {} + self._lerobot_features = self._build_lerobot_features() + self._thread_state = threading.local() + self.reset() + + def get_optim_params(self) -> dict: + return {} + + def reset(self): + # Reinitialize thread-local state so each worker gets its own queue/session + self._thread_state = threading.local() + + def _state(self): + state = self._thread_state + if not hasattr(state, "action_queue"): + state.action_queue = deque(maxlen=self.config.n_action_steps) + if not hasattr(state, "stub") or state.stub is None: + self._initialize_connection(state) + return state + + def _initialize_connection(self, state) -> None: + state.channel = grpc.insecure_channel( + self.config.server_address, + options=grpc_channel_options(), + ) + state.stub = services_pb2_grpc.AsyncInferenceStub(state.channel) + state.next_timestep = 0 + + policy_cfg = RemotePolicyConfig( + policy_type=self.config.remote_policy_type, + pretrained_name_or_path=self.config.remote_pretrained_name_or_path, + lerobot_features=self._lerobot_features, + actions_per_chunk=self.config.effective_actions_per_chunk, + device=self.config.remote_policy_device, + rename_map=self.config.rename_map, + ) + + payload = pickle.dumps(policy_cfg) # nosec B301 - config originates from local process + request = services_pb2.PolicySetup(data=payload) + + for attempt in range(1, self.config.retries + 1): + try: + state.stub.Ready(services_pb2.Empty(), timeout=self.config.request_timeout) + state.stub.SendPolicyInstructions(request, timeout=self.config.request_timeout) + logger.debug("Remote policy handshake completed on attempt %d", attempt) + return + except grpc.RpcError as err: + logger.warning("Remote policy handshake failed on attempt %d: %s", attempt, err) + self._close_channel(state) + if attempt == self.config.retries: + raise + time.sleep(0.1) + state.channel = grpc.insecure_channel( + self.config.server_address, + options=grpc_channel_options(), + ) + state.stub = services_pb2_grpc.AsyncInferenceStub(state.channel) + + def _close_channel(self, state) -> None: + if getattr(state, "channel", None) is not None: + state.channel.close() + state.stub = None + + def _build_lerobot_features(self) -> dict[str, dict[str, Any]]: + """ + Build a hw-style feature dictionary expected by the async inference server. + Vector features (state/env) are split into individual scalar names, while image features + are mapped to (H, W, C) tensors keyed by their camera name. + """ + features: dict[str, dict[str, Any]] = {} + vector_name_map: dict[str, list[str]] = {} + image_key_map: dict[str, str] = {} + + for key, feature in self.config.input_features.items(): + if feature.type in (FeatureType.STATE, FeatureType.ENV): + if not feature.shape or len(feature.shape) != 1: + raise ValueError( + f"RemotePolicy only supports 1D state features, got shape {feature.shape} for '{key}'." + ) + dim = feature.shape[0] + names = [f"{key.replace('.', '_')}_d{idx}" for idx in range(dim)] + features[key] = { + "dtype": "float32", + "shape": (dim,), + "names": names, + } + vector_name_map[key] = names + elif feature.type is FeatureType.VISUAL: + if not feature.shape or len(feature.shape) != 3: + raise ValueError( + f"RemotePolicy only supports 3D visual features, got shape {feature.shape} for '{key}'." 
+ ) + channels, height, width = feature.shape + camera_base = key.removeprefix(f"{OBS_STR}.images.") + # Ensure uniqueness if multiple features share the same suffix + raw_key = camera_base + counter = 1 + while raw_key in image_key_map.values(): + raw_key = f"{camera_base}_{counter}" + counter += 1 + + features[key] = { + "dtype": "video", + "shape": (height, width, channels), + "names": ["height", "width", "channels"], + } + image_key_map[key] = raw_key + else: + logger.debug("Skipping unsupported feature '%s' of type '%s'", key, feature.type) + + self._vector_name_map = vector_name_map + self._image_key_map = image_key_map + return features + + def _prepare_payload(self, batch: dict[str, Tensor]) -> dict[str, Any]: + if not batch: + raise ValueError("RemotePolicy received an empty batch.") + + payload: dict[str, Any] = {} + cpu_batch: dict[str, Any] = { + key: value.detach().cpu() if isinstance(value, torch.Tensor) else value + for key, value in batch.items() + } + + # Serialize vector features (state/env) into individual scalar entries + for key, names in self._vector_name_map.items(): + tensor = cpu_batch.get(key) + if tensor is None: + continue + + if isinstance(tensor, torch.Tensor): + if tensor.ndim == 2: + tensor = tensor.squeeze(0) + tensor = tensor.flatten() + if tensor.numel() != len(names): + raise ValueError( + f"Feature '{key}' expected {len(names)} values, got shape {tuple(tensor.shape)}." + ) + for idx, name in enumerate(names): + payload[name] = float(tensor[idx].item()) + else: + raise TypeError(f"Expected tensor for feature '{key}', got {type(tensor)}") + + # Serialize image features (convert to HWC uint8 tensors) + for key, raw_key in self._image_key_map.items(): + tensor = cpu_batch.get(key) + if tensor is None: + continue + + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"Expected tensor for image feature '{key}', got {type(tensor)}") + + if tensor.ndim == 4: + tensor = tensor.squeeze(0) + if tensor.ndim != 3: + raise ValueError( + f"Image feature '{key}' must have 3 dimensions after squeeze, got {tensor.ndim}" + ) + + if tensor.dtype != torch.uint8: + tensor = (tensor.clamp(0.0, 1.0) * 255.0).to(torch.uint8) + + payload[raw_key] = tensor.permute(1, 2, 0).contiguous() + + # Optional task/instruction keys + for extra_key in ["task", "instruction"]: + if extra_key in cpu_batch: + payload[extra_key] = cpu_batch[extra_key] + + for key, value in (self.config.additional_args or {}).items(): + payload[key] = value + + return payload + + def _timed_actions_to_tensor(self, timed_actions: list[TimedAction]) -> Tensor: + if not timed_actions: + raise RuntimeError("Remote policy server returned an empty action chunk.") + + actions = [] + for timed_action in timed_actions: + action = timed_action.get_action() + if isinstance(action, torch.Tensor): + actions.append(action.detach().cpu()) + else: + actions.append(torch.as_tensor(action, dtype=torch.float32)) + + stacked = torch.stack(actions, dim=0).unsqueeze(0) # (B=1, T, A) + return stacked.to(device=self.config.device, dtype=torch.float32) + + def _request_action_chunk(self, state, batch: dict[str, Tensor]) -> Tensor: + payload = self._prepare_payload(batch) + timestamp = time.time() + timestep = state.next_timestep + state.next_timestep += 1 + + observation = TimedObservation( + timestamp=timestamp, + timestep=timestep, + observation=payload, + must_go=True, + ) + packed = pickle.dumps(observation) # nosec B301 - observation built locally + + iterator = send_bytes_in_chunks( + packed, + services_pb2.Observation, + 
log_prefix="[RemotePolicy]", + silent=True, + ) + state.stub.SendObservations(iterator, timeout=self.config.request_timeout) + actions_msg = state.stub.GetActions(services_pb2.Empty(), timeout=self.config.request_timeout) + if not actions_msg.data: + raise RuntimeError("Remote policy server returned an empty response payload.") + + timed_actions = pickle.loads(actions_msg.data) # nosec B301 - server is trusted peer + return self._timed_actions_to_tensor(timed_actions) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict] | tuple[Tensor, None]: + raise NotImplementedError("RemotePolicy is inference-only") + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + last_error: Exception | None = None + + for attempt in range(1, self.config.retries + 1): + state = self._state() + try: + return self._request_action_chunk(state, batch) + except grpc.RpcError as err: + logger.warning("Remote policy RPC failed on attempt %d: %s", attempt, err) + last_error = err + self._close_channel(state) + time.sleep(0.1) + except Exception as err: + logger.error("Unexpected error when requesting remote action chunk: %s", err) + last_error = err + break + + assert last_error is not None + raise last_error + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + self.eval() + + queue = self._state().action_queue + + if len(queue) == 0: + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + queue.extend(actions.transpose(0, 1)) # [(B, A)] x T + + return queue.popleft() diff --git a/src/lerobot/policies/remote/processor_remote.py b/src/lerobot/policies/remote/processor_remote.py new file mode 100644 index 0000000000..7329f0dc05 --- /dev/null +++ b/src/lerobot/policies/remote/processor_remote.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass, field +from typing import Any + +import torch + +from lerobot.policies.remote.configuration_remote import RemoteConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + RenameObservationsProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) + + +def make_remote_pre_post_processors( + config: RemoteConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + rename_map: dict[str, str] = {}, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Custom pre/post processors for the Remote policy. + + Pre: + - Normalizer (if stats provided) + - AddBatchDimension + - AppendInferenceConfig (copies config.inference_config into the batch) + - Device placement + + Post: + - Device to CPU + - Unnormalize outputs (if stats provided) + """ + + # Pre: allow renaming features and add batch dim. Rename map can be overridden at runtime + # through preprocessor_overrides with the key "rename_observations_processor". 
+ input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map=rename_map), + AddBatchDimensionProcessorStep(), + ] + + # Minimal postprocessor: identity (no steps) + output_steps: list[ProcessorStep] = [] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) \ No newline at end of file
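
For reviewers who want to exercise the new policy outside `lerobot-eval`, here is a minimal client-side sketch (not part of this diff). It assumes an async inference policy server is already running on `localhost:8080`, that `PolicyFeature`/`FeatureType` are importable from `lerobot.configs.types`, and that the feature shapes below are placeholders that must be replaced with the features the remote policy actually expects; the checkpoint and rename map are taken from the README example above.

```python
# Minimal sketch of driving RemotePolicy directly from Python.
# Assumptions: a policy server is running on localhost:8080, PolicyFeature/FeatureType
# live in lerobot.configs.types, and the shapes below are placeholders.
import torch

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.remote import RemoteConfig, RemotePolicy

config = RemoteConfig(
    server_address="localhost:8080",
    remote_policy_type="pi05",
    remote_pretrained_name_or_path="lerobot/pi05_libero_finetuned",
    remote_policy_device="cuda",
    n_action_steps=10,
    rename_map={"observation.images.empty_camera_0": "observation.images.image"},
    # Placeholder shapes: they must match the remote policy's expected features.
    input_features={
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(8,)),
        "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    },
    output_features={
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    },
)

policy = RemotePolicy(config)  # gRPC handshake happens lazily on the first request, per thread

batch = {
    "observation.state": torch.zeros(1, 8),
    # CHW float in [0, 1]; the policy converts it to HWC uint8 before sending.
    "observation.images.image": torch.rand(1, 3, 224, 224),
    "task": "put the bowl on the plate",
}

# The first call fetches an action chunk from the server and queues n_action_steps of it;
# subsequent calls pop queued actions until the queue is empty.
action = policy.select_action(batch)
print(action.shape)  # (1, action_dim)
```

Any `additional_args` entries set on the config are merged into the payload sent with each observation (see `_prepare_payload`), so server-side knobs can be passed without changing the client code.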