Adds perceptive actor-critic class #114
Open

pascal-roth wants to merge 8 commits into main from feature/perceptive-nav-rl
+501 −31
Commits (8)
- 677a7b2: add files for perceptive example (pascal-roth)
- 6b2910f: working training (pascal-roth)
- 6edccbe: formatter (pascal-roth)
- 364dcab: formatting 1 (ClemensSchwarke)
- 1517bd0: formatting 2 (ClemensSchwarke)
- d7bfc7a: CNN docstrings (ClemensSchwarke)
- 0a1756d: format actor_critic_perceptive (ClemensSchwarke)
- 3132a7e: extend CNN to more configuration options and better exportability (ClemensSchwarke)
New file (269 added lines), shown here with the diff-table wrapping stripped:

```python
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.distributions import Normal
from typing import Any

from rsl_rl.networks import CNN, MLP, EmpiricalNormalization

from .actor_critic import ActorCritic


class ActorCriticPerceptive(ActorCritic):
    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        num_actions: int,
        actor_obs_normalization: bool = False,
        critic_obs_normalization: bool = False,
        actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        actor_cnn_cfg: dict[str, dict] | dict | None = None,
        critic_cnn_cfg: dict[str, dict] | dict | None = None,
        activation: str = "elu",
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        state_dependent_std: bool = False,
        **kwargs: dict[str, Any],
    ) -> None:
        if kwargs:
            print(
                "ActorCriticPerceptive.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs.keys()))
            )
        nn.Module.__init__(self)

        # Get the observation dimensions
        self.obs_groups = obs_groups
        num_actor_obs_1d = 0
        self.actor_obs_groups_1d = []
        actor_in_dims_2d = []
        actor_in_channels_2d = []
        self.actor_obs_groups_2d = []
        for obs_group in obs_groups["policy"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.actor_obs_groups_2d.append(obs_group)
                actor_in_dims_2d.append(obs[obs_group].shape[2:4])
                actor_in_channels_2d.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.actor_obs_groups_1d.append(obs_group)
                num_actor_obs_1d += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
        num_critic_obs_1d = 0
        self.critic_obs_groups_1d = []
        critic_in_dims_2d = []
        critic_in_channels_2d = []
        self.critic_obs_groups_2d = []
        for obs_group in obs_groups["critic"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.critic_obs_groups_2d.append(obs_group)
                critic_in_dims_2d.append(obs[obs_group].shape[2:4])
                critic_in_channels_2d.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.critic_obs_groups_1d.append(obs_group)
                num_critic_obs_1d += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")

        # Actor CNN
        if self.actor_obs_groups_2d:
            assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations."

            # Check if multiple 2D actor observations are provided
            if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()):
                assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), (
                    "The number of CNN configurations must match the number of 2D actor observations."
                )
            elif len(self.actor_obs_groups_2d) > 1:
                print(
                    "Only one CNN configuration for multiple 2D actor observations given, using the same configuration"
                    " for all groups."
                )
                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d)))
            else:
                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg]))

            # Create CNNs for each 2D actor observation
            self.actor_cnns = nn.ModuleDict()
            encoding_dim = 0
            for idx, obs_group in enumerate(self.actor_obs_groups_2d):
                self.actor_cnns[obs_group] = CNN(
                    input_dim=actor_in_dims_2d[idx],
                    input_channels=actor_in_channels_2d[idx],
                    **actor_cnn_cfg[obs_group],
                )
                print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
                # Get the output dimension of the CNN
                if self.actor_cnns[obs_group].output_channels is None:
                    encoding_dim += int(self.actor_cnns[obs_group].output_dim)  # type: ignore
                else:
                    raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.")
        else:
            self.actor_cnns = None
            encoding_dim = 0

        # Actor MLP
        self.state_dependent_std = state_dependent_std
        if self.state_dependent_std:
            self.actor = MLP(num_actor_obs_1d + encoding_dim, [2, num_actions], actor_hidden_dims, activation)
        else:
            self.actor = MLP(num_actor_obs_1d + encoding_dim, num_actions, actor_hidden_dims, activation)
        print(f"Actor MLP: {self.actor}")

        # Actor observation normalization (only for 1D actor observations)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()

        # Critic CNN
        if self.critic_obs_groups_2d:
            assert critic_cnn_cfg is not None, "A critic CNN configuration is required for 2D critic observations."

            # Check if multiple 2D critic observations are provided
            if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()):
                assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), (
                    "The number of CNN configurations must match the number of 2D critic observations."
                )
            elif len(self.critic_obs_groups_2d) > 1:
                print(
                    "Only one CNN configuration for multiple 2D critic observations given, using the same configuration"
                    " for all groups."
                )
                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d)))
            else:
                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg]))

            # Create CNNs for each 2D critic observation
            self.critic_cnns = nn.ModuleDict()
            encoding_dim = 0
            for idx, obs_group in enumerate(self.critic_obs_groups_2d):
                self.critic_cnns[obs_group] = CNN(
                    input_dim=critic_in_dims_2d[idx],
                    input_channels=critic_in_channels_2d[idx],
                    **critic_cnn_cfg[obs_group],
                )
                print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
                # Get the output dimension of the CNN
                if self.critic_cnns[obs_group].output_channels is None:
                    encoding_dim += int(self.critic_cnns[obs_group].output_dim)  # type: ignore
                else:
                    raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.")
        else:
            self.critic_cnns = None
            encoding_dim = 0

        # Critic MLP
        self.critic = MLP(num_critic_obs_1d + encoding_dim, 1, critic_hidden_dims, activation)
        print(f"Critic MLP: {self.critic}")

        # Critic observation normalization (only for 1D critic observations)
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        # Action noise
        self.noise_std_type = noise_std_type
        if self.state_dependent_std:
            torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
            if self.noise_std_type == "scalar":
                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
            elif self.noise_std_type == "log":
                torch.nn.init.constant_(
                    self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7))
                )
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        else:
            if self.noise_std_type == "scalar":
                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
            elif self.noise_std_type == "log":
                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution
        # Note: Populated in update_distribution
        self.distribution = None

        # Disable args validation for speedup
        Normal.set_default_validate_args(False)

    def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
        if self.actor_cnns is not None:
            # Encode the 2D actor observations
            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        super()._update_distribution(mlp_obs)

    def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        self._update_distribution(mlp_obs, cnn_obs)
        return self.distribution.sample()

    def act_inference(self, obs: TensorDict) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)

        if self.actor_cnns is not None:
            # Encode the 2D actor observations
            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        if self.state_dependent_std:
            return self.actor(mlp_obs)[..., 0, :]
        else:
            return self.actor(mlp_obs)

    def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_critic_obs(obs)
        mlp_obs = self.critic_obs_normalizer(mlp_obs)

        if self.critic_cnns is not None:
            # Encode the 2D critic observations
            cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        return self.critic(mlp_obs)

    def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.actor_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.critic_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def update_normalization(self, obs: TensorDict) -> None:
        if self.actor_obs_normalization:
            actor_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs, _ = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)
```
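The constructor's core bookkeeping splits each observation group by tensor rank: rank-2 `(B, C)` groups are concatenated into the MLP input, while rank-4 `(B, C, H, W)` groups each get their own CNN encoder. A minimal shape-only sketch of that dispatch, using hypothetical group names and shapes (not from this PR):

```python
def split_obs_groups(shapes: dict[str, tuple[int, ...]]) -> tuple[int, list[str], list[str]]:
    """Split observation groups by rank, mirroring the dispatch in
    ActorCriticPerceptive.__init__: rank-2 (B, C) tensors feed the MLP,
    rank-4 (B, C, H, W) tensors go to per-group CNN encoders."""
    num_1d = 0
    groups_1d: list[str] = []
    groups_2d: list[str] = []
    for name, shape in shapes.items():
        if len(shape) == 4:  # B, C, H, W -> CNN branch
            groups_2d.append(name)
        elif len(shape) == 2:  # B, C -> concatenated MLP input
            groups_1d.append(name)
            num_1d += shape[-1]  # accumulate flat feature width
        else:
            raise ValueError(f"Invalid observation shape for {name}: {shape}")
    return num_1d, groups_1d, groups_2d


# Hypothetical example: proprioception and commands plus one depth image
shapes = {"proprio": (64, 48), "commands": (64, 3), "depth": (64, 1, 58, 87)}
print(split_obs_groups(shapes))  # (51, ['proprio', 'commands'], ['depth'])
```

The MLP then takes `num_1d` plus the summed CNN encoding widths as its input size, which is why the class requires the CNN outputs to be flattened.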
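The CNN-config handling in the constructor accepts either one shared config dict or a per-group mapping of configs, and normalizes both to a per-group mapping. A standalone sketch of that fan-out logic (group names and config keys below are hypothetical, not taken from the rsl_rl `CNN` API):

```python
def fan_out_cnn_cfg(cnn_cfg: dict, groups_2d: list[str]) -> dict[str, dict]:
    """Normalize a CNN config to one entry per 2D observation group,
    mirroring the branching in ActorCriticPerceptive.__init__."""
    if len(groups_2d) > 1 and all(isinstance(v, dict) for v in cnn_cfg.values()):
        # Already a per-group mapping: sizes must agree.
        assert len(cnn_cfg) == len(groups_2d), (
            "The number of CNN configurations must match the number of 2D observations."
        )
        return cnn_cfg
    if len(groups_2d) > 1:
        # One shared config for several groups: reuse it for each.
        return dict(zip(groups_2d, [cnn_cfg] * len(groups_2d)))
    # Single group: wrap the config under that group's name.
    return dict(zip(groups_2d, [cnn_cfg]))


# Hypothetical shared config applied to two image-like groups
shared = {"out_channels": [16, 32], "kernel_size": 3}
cfg = fan_out_cnn_cfg(shared, ["depth", "height_map"])
print(sorted(cfg))  # ['depth', 'height_map']
```

Note that the shared-config branch stores the same dict object under every group name, so the groups share one configuration rather than receiving independent copies.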