feat(policies): Add remote http policy, to benchmark custom models #2330
Open

grach0v wants to merge 9 commits into huggingface:main from grach0v:main
Changes from 4 commits (9 commits in total):

- 84fa84d Remote policy v0.1 (grach0v)
- d6010a4 Simplify changes (denis-grachev)
- 52581d2 Fix max_parallel_workers and some other little bugs (denis-grachev)
- 76d64e6 Little bug fixed (denis-grachev)
- d441d5f Adjust for copilot codereview (denis-grachev)
- 30c91a0 move to grpc (denis-grachev)
- 23e8fda Remove unused messaging.py (denis-grachev)
- d6427f0 restore toml file (denis-grachev)
- 5b58269 cosmetic fixes (denis-grachev)
New file (31 added lines): a minimal FastAPI example server exposing the `/predict` endpoint that the remote policy calls. It unpacks the observation payload, infers the batch size from any array-like input, and returns a zero-filled action tensor.

```python
import torch
import numpy as np
from fastapi import FastAPI, Request, Response

from lerobot.utils.messaging import pack_msg, unpack_msg

app = FastAPI()


@app.post("/predict")
async def predict(request: Request):
    data = await request.body()
    obs_input = unpack_msg(data)

    inf_cfg = obs_input.get("inference_config", {})
    dataset_info = obs_input.get("dataset_info", {})
    n_action_steps = inf_cfg.get("n_action_steps", 10)
    action_dim = dataset_info.get("action_dof", 7)

    # Try to infer batch size from any array-like input
    B = None
    for v in obs_input.values():
        if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray):
            if v.ndim >= 1:
                B = int(v.shape[0])
                break

    actions = torch.zeros((B, n_action_steps, action_dim), dtype=torch.float32)

    packed = pack_msg(actions)
    return Response(content=packed, media_type="application/octet-stream")
```
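For context, a round trip against this endpoint can be exercised with a small client that mirrors what `RemotePolicy` does further down in this diff: pack a dict of tensors, POST it, and unpack the returned action tensor. This is a sketch, not part of the PR; it assumes the server above is running on `http://localhost:8000` and that `pack_msg`/`unpack_msg` accept and return the same kinds of objects they do in this diff. The observation key and shapes are placeholders.

```python
# Sketch of a manual client for the /predict stub above (not part of the PR).
# Assumes the FastAPI app is served at http://localhost:8000 and that
# pack_msg/unpack_msg convert dicts/tensors to and from bytes as used in this diff.
import requests
import torch

from lerobot.utils.messaging import pack_msg, unpack_msg

payload = {
    "observation.state": torch.zeros((2, 7)),    # hypothetical batch of 2, state dim 7
    "inference_config": {"n_action_steps": 5},   # read by the server stub
    "dataset_info": {"action_dof": 7},
}

resp = requests.post(
    "http://localhost:8000/predict",
    data=pack_msg(payload),
    headers={"Content-Type": "application/octet-stream"},
    timeout=30.0,
)
resp.raise_for_status()

actions = unpack_msg(resp.content)
print(actions.shape)  # expected (2, 5, 7) from the zero-filled stub response
```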
New file (5 added lines): the package `__init__.py`, re-exporting the public API of the remote policy.

```python
from .configuration_remote import RemoteConfig
from .modeling_remote import RemotePolicy
from .processor_remote import make_remote_pre_post_processors

__all__ = ["RemoteConfig", "RemotePolicy", "make_remote_pre_post_processors"]
```
New file (55 added lines, truncated in this view): `configuration_remote.py`, defining the `RemoteConfig` dataclass registered under the policy type "remote".

```python
from dataclasses import dataclass, field
from typing import Any

from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim.optimizers import AdamWConfig


@PreTrainedConfig.register_subclass("remote")
@dataclass
class RemoteConfig(PreTrainedConfig):
    # Identity and device placement
    type: str = field(default="remote", metadata={"help": "Policy type name"})
    device: str = field(default="cpu", metadata={"help": "Device used for returned tensors"})

    # Action execution
    # How many environment steps to execute per policy call. Used by the runtime action queue.
    n_action_steps: int = field(default=1, metadata={"help": "Number of env steps to execute per call"})

    # Remote-specific
    server_url: str = field(default="http://localhost:8000", metadata={"help": "Remote policy server URL"})
    timeout: float = field(default=30.0, metadata={"help": "HTTP timeout in seconds"})
    attempts: int = field(default=1, metadata={"help": "Number of retry attempts for failed requests"})

    # Additional arguments to inject directly into the observation dict (e.g. {"inference_config": {...}})
    additional_args: dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Extra observation keys to inject directly into observation"},
    )

    # --- Abstract API implementations required by PreTrainedConfig ---
    def get_optimizer_preset(self) -> AdamWConfig:
        """Remote policy is inference-only; return a inert preset for API compatibility."""
```

A review suggestion corrects the docstring typo: "return a inert preset" should read "return an inert preset".
New file (94 added lines): `modeling_remote.py`, the `RemotePolicy` class that proxies inference to a remote HTTP server and queues returned action chunks.

```python
from collections import deque
import threading

import numpy as np
import requests
import torch
from torch import Tensor

from lerobot.utils.messaging import pack_msg, unpack_msg
from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_remote import RemoteConfig


class RemotePolicy(PreTrainedPolicy):
    """
    A policy that proxies inference to a remote HTTP server.
    """

    config_class = RemoteConfig
    name = "remote"

    def __init__(self, config: RemoteConfig):
        super().__init__(config)
        self.server_url = config.server_url.rstrip("/")
        self.timeout = config.timeout
        self._thread_state = threading.local()
        self.reset()

    def get_optim_params(self) -> dict:
        return {}

    def reset(self):
        # Reinitialize thread-local state so each worker gets its own queue/session
        self._thread_state = threading.local()

    def _state(self):
        state = self._thread_state
        if not hasattr(state, "session"):
            state.session = requests.Session()
        if not hasattr(state, "action_queue"):
            state.action_queue = deque(maxlen=self.config.n_action_steps)
        return state

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict] | tuple[Tensor, None]:
        raise NotImplementedError("RemotePolicy is inference-only")

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
        state = self._state()

        # Build payload with raw tensors/arrays; pack_msg handles encoding
        add_args = self.config.additional_args or {}
        payload = batch | add_args

        packed = pack_msg(payload)

        last_exception = None
        for _ in range(self.config.attempts):
            try:
                resp = state.session.post(
                    f"{self.server_url}/predict",
                    data=packed,
                    headers={"Content-Type": "application/octet-stream"},
                    timeout=self.timeout,
                )
                resp.raise_for_status()
                last_exception = None  # a successful retry clears any earlier failure
                break
            except requests.RequestException as e:
                last_exception = e

        if last_exception:
            raise last_exception

        unpacked = unpack_msg(resp.content)
        if isinstance(unpacked, torch.Tensor):
            actions = unpacked
        else:
            actions_np = np.asarray(unpacked)
            actions = torch.from_numpy(actions_np)

        device = torch.device(self.config.device)
        return actions.to(device=device, dtype=torch.float32)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
        self.eval()

        queue = self._state().action_queue

        if len(queue) == 0:
            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
            queue.extend(actions.transpose(0, 1))  # [(B, A)] x T

        return queue.popleft()
```
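As a usage sketch (not part of the PR), driving the policy in a control loop reduces to building a batch of observation tensors and calling `select_action`; the observation keys and shapes below are placeholders, and the server from the first file is assumed to be running locally.

```python
# Illustrative control-loop usage of RemotePolicy (placeholder observation keys/shapes).
import torch

from lerobot.policies.remote.configuration_remote import RemoteConfig
from lerobot.policies.remote.modeling_remote import RemotePolicy

config = RemoteConfig(server_url="http://localhost:8000", n_action_steps=5)
policy = RemotePolicy(config)
policy.reset()  # fresh per-thread action queue and HTTP session

batch = {
    "observation.state": torch.zeros((1, 7)),                  # hypothetical state vector
    "observation.images.top": torch.zeros((1, 3, 224, 224)),   # hypothetical camera frame
}

# The first call fetches a chunk of n_action_steps actions from the server;
# subsequent calls pop queued actions until the chunk is exhausted.
action = policy.select_action(batch)
print(action.shape)  # (B, action_dim) per step, e.g. (1, 7) with the example server
```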
New file (65 added lines): `processor_remote.py`, building the pre/postprocessor pipelines for the remote policy.

```python
from dataclasses import dataclass, field
from typing import Any

import torch

from lerobot.policies.remote.configuration_remote import RemoteConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    RenameObservationsProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    ProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
    POLICY_POSTPROCESSOR_DEFAULT_NAME,
    POLICY_PREPROCESSOR_DEFAULT_NAME,
)


def make_remote_pre_post_processors(
    config: RemoteConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    rename_map: dict[str, str] = {},
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Custom pre/post processors for the Remote policy.

    Pre:
        - Normalizer (if stats provided)
        - AddBatchDimension
        - AppendInferenceConfig (copies config.inference_config into the batch)
        - Device placement

    Post:
        - Device to CPU
        - Unnormalize outputs (if stats provided)
    """

    # Pre: allow renaming features and add batch dim. Rename map can be overridden at runtime
    # through preprocessor_overrides with the key "rename_observations_processor".
    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map=rename_map),
        AddBatchDimensionProcessorStep(),
    ]

    # Minimal postprocessor: identity (no steps)
    output_steps: list[ProcessorStep] = []

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
```
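For completeness, a hedged sketch of wiring the processors up, using only the signature shown above; how the returned pipelines are applied to batches and actions is left to the existing lerobot runtime, and the rename mapping below is hypothetical.

```python
# Sketch: building the pre/post processors for a RemoteConfig (illustration only).
from lerobot.policies.remote.configuration_remote import RemoteConfig
from lerobot.policies.remote.processor_remote import make_remote_pre_post_processors

config = RemoteConfig()

# An optional rename_map translates robot-specific feature names into
# whatever keys the remote server expects (hypothetical mapping shown).
preprocessor, postprocessor = make_remote_pre_post_processors(
    config,
    rename_map={"observation.images.cam_high": "observation.images.top"},
)
```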
Review comment: Potential `NoneType` error if `B` is not inferred from any tensor/array in `obs_input`. Add a fallback default or raise a descriptive error if `B` remains `None`.
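One way to address this, as a sketch of the reviewer's suggestion rather than code from the PR, is an explicit guard in the example server's `predict()` handler before the action tensor is allocated:

```python
    # Possible guard following the review comment above: fail loudly instead of
    # passing B=None to torch.zeros.
    if B is None:
        raise ValueError(
            "Could not infer batch size: no tensor/ndarray with ndim >= 1 in obs_input"
        )
    # Alternatively, fall back to a default batch size: B = B if B is not None else 1

    actions = torch.zeros((B, n_action_steps, action_dim), dtype=torch.float32)
```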