diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index 1479c06a..410d52ea 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -11,7 +11,7 @@
 from itertools import chain
 from tensordict import TensorDict
 
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
+from rsl_rl.modules import ActorCritic, ActorCriticPerceptive, ActorCriticRecurrent
 from rsl_rl.modules.rnd import RandomNetworkDistillation
 from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import string_to_callable
@@ -20,12 +20,12 @@ class PPO:
     """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
 
-    policy: ActorCritic | ActorCriticRecurrent
+    policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive
     """The actor critic module."""
 
     def __init__(
         self,
-        policy: ActorCritic | ActorCriticRecurrent,
+        policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive,
         num_learning_epochs: int = 5,
         num_mini_batches: int = 4,
         clip_param: float = 0.2,
diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index efb8613a..7803aa08 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -6,6 +6,7 @@
 """Definitions for neural-network components for RL-agents."""
 
 from .actor_critic import ActorCritic
+from .actor_critic_perceptive import ActorCriticPerceptive
 from .actor_critic_recurrent import ActorCriticRecurrent
 from .rnd import RandomNetworkDistillation, resolve_rnd_config
 from .student_teacher import StudentTeacher
@@ -14,6 +15,7 @@
 
 __all__ = [
     "ActorCritic",
+    "ActorCriticPerceptive",
     "ActorCriticRecurrent",
     "RandomNetworkDistillation",
     "StudentTeacher",
diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py
index 9f01b2f4..da55e704 100644
--- a/rsl_rl/modules/actor_critic.py
+++ b/rsl_rl/modules/actor_critic.py
@@ -49,9 +49,8 @@ def __init__(
             assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
             num_critic_obs += obs[obs_group].shape[-1]
-        self.state_dependent_std = state_dependent_std
-
         # Actor
+        self.state_dependent_std = state_dependent_std
         if self.state_dependent_std:
             self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
         else:
             self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation)
@@ -121,7 +120,7 @@ def action_std(self) -> torch.Tensor:
     def entropy(self) -> torch.Tensor:
         return self.distribution.entropy().sum(dim=-1)
 
-    def _update_distribution(self, obs: TensorDict) -> None:
+    def _update_distribution(self, obs: torch.Tensor) -> None:
         if self.state_dependent_std:
             # Compute mean and standard deviation
             mean_and_std = self.actor(obs)
diff --git a/rsl_rl/modules/actor_critic_perceptive.py b/rsl_rl/modules/actor_critic_perceptive.py
new file mode 100644
index 00000000..46533860
--- /dev/null
+++ b/rsl_rl/modules/actor_critic_perceptive.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from tensordict import TensorDict
+from torch.distributions import Normal
+from typing import Any
+
+from rsl_rl.networks import CNN, MLP, EmpiricalNormalization
+
+from .actor_critic import ActorCritic
+
+
+class ActorCriticPerceptive(ActorCritic):
+    def __init__(
+        self,
+        obs: TensorDict,
+        obs_groups: dict[str, list[str]],
+        num_actions: int,
+        actor_obs_normalization: bool = False,
+        critic_obs_normalization: bool = False,
+        actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
+        critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
+        actor_cnn_cfg: dict[str, dict] | dict | None = None,
+        critic_cnn_cfg: dict[str, dict] | dict | None = None,
+        activation: str = "elu",
+        init_noise_std: float = 1.0,
+        noise_std_type: str = "scalar",
+        state_dependent_std: bool = False,
+        **kwargs: dict[str, Any],
+    ) -> None:
+        if kwargs:
+            print(
+                "ActorCriticPerceptive.__init__ got unexpected arguments, which will be ignored: "
+                + str([key for key in kwargs])
+            )
+        nn.Module.__init__(self)
+
+        # Get the observation dimensions
+        self.obs_groups = obs_groups
+        num_actor_obs_1d = 0
+        self.actor_obs_groups_1d = []
+        actor_in_dims_2d = []
+        actor_in_channels_2d = []
+        self.actor_obs_groups_2d = []
+        for obs_group in obs_groups["policy"]:
+            if len(obs[obs_group].shape) == 4:  # B, C, H, W
+                self.actor_obs_groups_2d.append(obs_group)
+                actor_in_dims_2d.append(obs[obs_group].shape[2:4])
+                actor_in_channels_2d.append(obs[obs_group].shape[1])
+            elif len(obs[obs_group].shape) == 2:  # B, C
+                self.actor_obs_groups_1d.append(obs_group)
+                num_actor_obs_1d += obs[obs_group].shape[-1]
+            else:
+                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
+        num_critic_obs_1d = 0
+        self.critic_obs_groups_1d = []
+        critic_in_dims_2d = []
+        critic_in_channels_2d = []
+        self.critic_obs_groups_2d = []
+        for obs_group in obs_groups["critic"]:
+            if len(obs[obs_group].shape) == 4:  # B, C, H, W
+                self.critic_obs_groups_2d.append(obs_group)
+                critic_in_dims_2d.append(obs[obs_group].shape[2:4])
+                critic_in_channels_2d.append(obs[obs_group].shape[1])
+            elif len(obs[obs_group].shape) == 2:  # B, C
+                self.critic_obs_groups_1d.append(obs_group)
+                num_critic_obs_1d += obs[obs_group].shape[-1]
+            else:
+                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
+
+        # Actor CNN
+        if self.actor_obs_groups_2d:
+            assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations."
+
+            # Check if multiple 2D actor observations are provided
+            if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()):
+                assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), (
+                    "The number of CNN configurations must match the number of 2D actor observations."
+                )
+            elif len(self.actor_obs_groups_2d) > 1:
+                print(
+                    "Only one CNN configuration given for multiple 2D actor observations, using the same"
+                    " configuration for all groups."
+ ) + actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d))) + else: + actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg])) + + # Create CNNs for each 2D actor observation + self.actor_cnns = nn.ModuleDict() + encoding_dim = 0 + for idx, obs_group in enumerate(self.actor_obs_groups_2d): + self.actor_cnns[obs_group] = CNN( + input_dim=actor_in_dims_2d[idx], + input_channels=actor_in_channels_2d[idx], + **actor_cnn_cfg[obs_group], + ) + print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}") + # Get the output dimension of the CNN + if self.actor_cnns[obs_group].output_channels is None: + encoding_dim += int(self.actor_cnns[obs_group].output_dim) # type: ignore + else: + raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.") + else: + self.actor_cnns = None + encoding_dim = 0 + + # Actor MLP + self.state_dependent_std = state_dependent_std + if self.state_dependent_std: + self.actor = MLP(num_actor_obs_1d + encoding_dim, [2, num_actions], actor_hidden_dims, activation) + else: + self.actor = MLP(num_actor_obs_1d + encoding_dim, num_actions, actor_hidden_dims, activation) + print(f"Actor MLP: {self.actor}") + + # Actor observation normalization (only for 1D actor observations) + self.actor_obs_normalization = actor_obs_normalization + if actor_obs_normalization: + self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d) + else: + self.actor_obs_normalizer = torch.nn.Identity() + + # Critic CNN + if self.critic_obs_groups_2d: + assert critic_cnn_cfg is not None, " A critic CNN configuration is required for 2D critic observations." + + # check if multiple 2D critic observations are provided + if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()): + assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), ( + "The number of CNN configurations must match the number of 2D critic observations." + ) + elif len(self.critic_obs_groups_2d) > 1: + print( + "Only one CNN configuration for multiple 2D critic observations given, using the same configuration" + " for all groups." 
+ ) + critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d))) + else: + critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg])) + + # Create CNNs for each 2D critic observation + self.critic_cnns = nn.ModuleDict() + encoding_dim = 0 + for idx, obs_group in enumerate(self.critic_obs_groups_2d): + self.critic_cnns[obs_group] = CNN( + input_dim=critic_in_dims_2d[idx], + input_channels=critic_in_channels_2d[idx], + **critic_cnn_cfg[obs_group], + ) + print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") + # Get the output dimension of the CNN + if self.critic_cnns[obs_group].output_channels is None: + encoding_dim += int(self.critic_cnns[obs_group].output_dim) # type: ignore + else: + raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.") + else: + self.critic_cnns = None + encoding_dim = 0 + + # Critic MLP + self.critic = MLP(num_critic_obs_1d + encoding_dim, 1, critic_hidden_dims, activation) + print(f"Critic MLP: {self.critic}") + + # Critic observation normalization (only for 1D critic observations) + self.critic_obs_normalization = critic_obs_normalization + if critic_obs_normalization: + self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d) + else: + self.critic_obs_normalizer = torch.nn.Identity() + + # Action noise + self.noise_std_type = noise_std_type + if self.state_dependent_std: + torch.nn.init.zeros_(self.actor[-2].weight[num_actions:]) + if self.noise_std_type == "scalar": + torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std) + elif self.noise_std_type == "log": + torch.nn.init.constant_( + self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)) + ) + else: + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") + else: + if self.noise_std_type == "scalar": + self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) + elif self.noise_std_type == "log": + self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions))) + else: + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") + + # Action distribution + # Note: Populated in update_distribution + self.distribution = None + + # Disable args validation for speedup + Normal.set_default_validate_args(False) + + def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None: + if self.actor_cnns is not None: + # Encode the 2D actor observations + cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # Concatenate to the MLP observations + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + super()._update_distribution(mlp_obs) + + def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: + mlp_obs, cnn_obs = self.get_actor_obs(obs) + mlp_obs = self.actor_obs_normalizer(mlp_obs) + self._update_distribution(mlp_obs, cnn_obs) + return self.distribution.sample() + + def act_inference(self, obs: TensorDict) -> torch.Tensor: + mlp_obs, cnn_obs = self.get_actor_obs(obs) + mlp_obs = self.actor_obs_normalizer(mlp_obs) + + if self.actor_cnns is not None: + # Encode the 2D actor observations + cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # Concatenate to the MLP observations + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + if self.state_dependent_std: + return self.actor(obs)[..., 0, :] + else: + return self.actor(mlp_obs) + + def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: + mlp_obs, cnn_obs = self.get_critic_obs(obs) + mlp_obs = self.critic_obs_normalizer(mlp_obs) + + if self.critic_cnns is not None: + # Encode the 2D critic observations + cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d] + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # Concatenate to the MLP observations + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + return self.critic(mlp_obs) + + def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d] + obs_dict_2d = {} + for obs_group in self.actor_obs_groups_2d: + obs_dict_2d[obs_group] = obs[obs_group] + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d + + def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d] + obs_dict_2d = {} + for obs_group in self.critic_obs_groups_2d: + obs_dict_2d[obs_group] = obs[obs_group] + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d + + def update_normalization(self, obs: TensorDict) -> None: + if self.actor_obs_normalization: + actor_obs, _ = self.get_actor_obs(obs) + self.actor_obs_normalizer.update(actor_obs) + if self.critic_obs_normalization: + critic_obs, _ = self.get_critic_obs(obs) + self.critic_obs_normalizer.update(critic_obs) diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py index 509b6821..0c3805be 100644 --- a/rsl_rl/modules/actor_critic_recurrent.py +++ b/rsl_rl/modules/actor_critic_recurrent.py @@ -61,9 +61,8 @@ def __init__( assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations." 
             num_critic_obs += obs[obs_group].shape[-1]
 
-        self.state_dependent_std = state_dependent_std
-
         # Actor
+        self.state_dependent_std = state_dependent_std
         self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
         if self.state_dependent_std:
             self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
@@ -138,7 +137,7 @@ def reset(self, dones: torch.Tensor | None = None) -> None:
     def forward(self) -> NoReturn:
         raise NotImplementedError
 
-    def _update_distribution(self, obs: TensorDict) -> None:
+    def _update_distribution(self, obs: torch.Tensor) -> None:
         if self.state_dependent_std:
             # Compute mean and standard deviation
             mean_and_std = self.actor(obs)
diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py
index 7ede0665..5050fcc0 100644
--- a/rsl_rl/networks/__init__.py
+++ b/rsl_rl/networks/__init__.py
@@ -5,11 +5,13 @@
 
 """Definitions for components of modules."""
 
+from .cnn import CNN
 from .memory import HiddenState, Memory
 from .mlp import MLP
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
 
 __all__ = [
+    "CNN",
     "MLP",
     "EmpiricalDiscountedVariationNormalization",
     "EmpiricalNormalization",
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
new file mode 100644
index 00000000..bac6274b
--- /dev/null
+++ b/rsl_rl/networks/cnn.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import math
+import torch
+from torch import nn as nn
+
+from rsl_rl.utils import get_param, resolve_nn_activation
+
+
+class CNN(nn.Sequential):
+    """Convolutional Neural Network (CNN).
+
+    The network is a sequence of convolutional layers, each followed by an optional normalization layer, an
+    activation function, and optional max pooling. The final output can be globally pooled and/or flattened.
+    """
+
+    def __init__(
+        self,
+        input_dim: tuple[int, int],
+        input_channels: int,
+        output_channels: tuple[int] | list[int],
+        kernel_size: int | tuple[int] | list[int],
+        stride: int | tuple[int] | list[int] = 1,
+        dilation: int | tuple[int] | list[int] = 1,
+        padding: str = "none",
+        norm: str | tuple[str] | list[str] = "none",
+        activation: str = "elu",
+        max_pool: bool | tuple[bool] | list[bool] = False,
+        global_pool: str = "none",
+        flatten: bool = True,
+    ) -> None:
+        """Initialize the CNN.
+
+        Args:
+            input_dim: Height and width of the input.
+            input_channels: Number of input channels.
+            output_channels: List of output channels for each convolutional layer.
+            kernel_size: List of kernel sizes for each convolutional layer or a single kernel size for all layers.
+            stride: List of strides for each convolutional layer or a single stride for all layers.
+            dilation: List of dilations for each convolutional layer or a single dilation for all layers.
+            padding: Padding type to use. Either 'none', 'zeros', 'reflect', 'replicate', or 'circular'.
+            norm: List of normalization types for each convolutional layer or a single type for all layers. Either
+                'none', 'batch', or 'layer'.
+            activation: Activation function to use.
+            max_pool: List of booleans indicating whether to apply max pooling after each convolutional layer or a
+                single boolean for all layers.
+            global_pool: Global pooling type to apply at the end. Either 'none', 'max', or 'avg'.
+            flatten: Whether to flatten the output tensor.
+ """ + super().__init__() + + # Resolve activation function + activation_function = resolve_nn_activation(activation) + + # Create layers sequentially + layers = [] + last_channels = input_channels + last_dim = input_dim + for idx in range(len(output_channels)): + # Get parameters for the current layer + k = get_param(kernel_size, idx) + s = get_param(stride, idx) + d = get_param(dilation, idx) + p = ( + _compute_padding(last_dim, k, s, d) + if padding in ["zeros", "reflect", "replicate", "circular"] + else (0, 0) + ) + + # Append convolutional layer + layers.append( + nn.Conv2d( + in_channels=last_channels, + out_channels=output_channels[idx], + kernel_size=k, + stride=s, + padding=p, + dilation=d, + padding_mode=padding if padding in ["zeros", "reflect", "replicate", "circular"] else "zeros", + ) + ) + + # Append normalization layer if specified + n = get_param(norm, idx) + if n == "none": + pass + elif n == "batch": + layers.append(nn.BatchNorm2d(output_channels[idx])) + elif n == "layer": + norm_input_dim = _compute_output_dim(last_dim, k, s, d, p) + layers.append(nn.LayerNorm([output_channels[idx], norm_input_dim[0], norm_input_dim[1]])) + else: + raise ValueError( + f"Unsupported normalization type: {n}. Supported types are 'none', 'batch', and 'layer'." + ) + + # Append activation function + layers.append(activation_function) + + # Apply max pooling if specified + if get_param(max_pool, idx): + layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # Update last channels and dimensions + last_channels = output_channels[idx] + last_dim = _compute_output_dim(last_dim, k, s, d, p, is_max_pool=get_param(max_pool, idx)) + + # Apply global pooling if specified + if global_pool == "none": + pass + elif global_pool == "max": + layers.append(nn.AdaptiveMaxPool2d((1, 1))) + last_dim = (1, 1) + elif global_pool == "avg": + layers.append(nn.AdaptiveAvgPool2d((1, 1))) + last_dim = (1, 1) + else: + raise ValueError( + f"Unsupported global pooling type: {global_pool}. Supported types are 'none', 'max', and 'avg'." + ) + + # Apply flattening if specified + if flatten: + layers.append(nn.Flatten(start_dim=1)) + + # Store final output dimension + self._output_channels = last_channels if not flatten else None + self._output_dim = last_dim if not flatten else last_channels * last_dim[0] * last_dim[1] + + # Register the layers + for idx, layer in enumerate(layers): + self.add_module(f"{idx}", layer) + + @property + def output_channels(self) -> int | None: + """Get the number of output channels or None if output is flattened.""" + return self._output_channels + + @property + def output_dim(self) -> tuple[int, int] | int: + """Get the output height and width or total output dimension if output is flattened.""" + return self._output_dim + + def init_weights(self) -> None: + """Initialize the weights of the CNN with Xavier initialization.""" + for idx, module in enumerate(self): + if isinstance(module, nn.Conv2d): + torch.nn.init.kaiming_normal_(module.weight) + torch.nn.init.zeros_(module.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the CNN.""" + for layer in self: + x = layer(x) + return x + + +def _compute_padding(input_hw: tuple[int, int], kernel: int, stride: int, dilation: int) -> tuple[int, int]: + """Compute the optimal padding for the current layer. 
+
+    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+    """
+    h = math.ceil((stride * math.floor(input_hw[0] / stride) - input_hw[0] - stride + dilation * (kernel - 1) + 1) / 2)
+    w = math.ceil((stride * math.floor(input_hw[1] / stride) - input_hw[1] - stride + dilation * (kernel - 1) + 1) / 2)
+    return (h, w)
+
+
+def _compute_output_dim(
+    input_hw: tuple[int, int],
+    kernel: int,
+    stride: int,
+    dilation: int,
+    padding: tuple[int, int],
+    is_max_pool: bool = False,
+) -> tuple[int, int]:
+    """Compute the output height and width of the current layer.
+
+    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+    """
+    h = math.floor((input_hw[0] + 2 * padding[0] - dilation * (kernel - 1) - 1) / stride + 1)
+    w = math.floor((input_hw[1] + 2 * padding[1] - dilation * (kernel - 1) - 1) / stride + 1)
+
+    if is_max_pool:
+        h = math.ceil(h / 2)
+        w = math.ceil(w / 2)
+
+    return (h, w)
diff --git a/rsl_rl/networks/memory.py b/rsl_rl/networks/memory.py
index dd40afc2..dc67abed 100644
--- a/rsl_rl/networks/memory.py
+++ b/rsl_rl/networks/memory.py
@@ -18,9 +18,9 @@
 
 
 class Memory(nn.Module):
-    """Memory module for recurrent networks.
+    """Memory network for recurrent architectures.
 
-    This module is used to store the hidden state of the policy. It currently supports GRU and LSTM.
+    This network is used to store the hidden state of the policy. It currently supports GRU and LSTM.
     """
 
     def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:
diff --git a/rsl_rl/networks/mlp.py b/rsl_rl/networks/mlp.py
index f01a7577..25f26804 100644
--- a/rsl_rl/networks/mlp.py
+++ b/rsl_rl/networks/mlp.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from functools import reduce
 
-from rsl_rl.utils import resolve_nn_activation
+from rsl_rl.utils import get_param, resolve_nn_activation
 
 
 class MLP(nn.Sequential):
@@ -82,27 +82,13 @@ def init_weights(self, scales: float | tuple[float]) -> None:
         Args:
             scales: Scale factor for the weights.
         """
-
-        def get_scale(idx: int) -> float:
-            """Get the scale factor for the weights of the MLP.
-
-            Args:
-                idx: Index of the layer.
-            """
-            return scales[idx] if isinstance(scales, (list, tuple)) else scales
-
-        # Initialize the weights
         for idx, module in enumerate(self):
             if isinstance(module, nn.Linear):
-                nn.init.orthogonal_(module.weight, gain=get_scale(idx))
+                nn.init.orthogonal_(module.weight, gain=get_param(scales, idx))
                 nn.init.zeros_(module.bias)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass of the MLP.
-
-        Args:
-            x: Input tensor.
- """ + """Forward pass of the MLP.""" for layer in self: x = layer(x) return x diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 46a9b524..2b0d7664 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -16,7 +16,13 @@ import rsl_rl from rsl_rl.algorithms import PPO from rsl_rl.env import VecEnv -from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config +from rsl_rl.modules import ( + ActorCritic, + ActorCriticPerceptive, + ActorCriticRecurrent, + resolve_rnd_config, + resolve_symmetry_config, +) from rsl_rl.utils import resolve_obs_groups, store_code_state @@ -414,7 +420,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO: # Initialize the policy actor_critic_class = eval(self.policy_cfg.pop("class_name")) - actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class( + actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive = actor_critic_class( obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py index a11074e0..a44bc678 100644 --- a/rsl_rl/utils/__init__.py +++ b/rsl_rl/utils/__init__.py @@ -6,6 +6,7 @@ """Helper functions.""" from .utils import ( + get_param, resolve_nn_activation, resolve_obs_groups, resolve_optimizer, @@ -16,6 +17,7 @@ ) __all__ = [ + "get_param", "resolve_nn_activation", "resolve_obs_groups", "resolve_optimizer", diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index 7a044e83..c1638d29 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -12,7 +12,20 @@ import torch import warnings from tensordict import TensorDict -from typing import Callable +from typing import Any, Callable + + +def get_param(param: Any, idx: int) -> Any: + """Get a parameter for the given index. + + Args: + param: Parameter or list/tuple of parameters. + idx: Index to get the parameter for. + """ + if isinstance(param, (tuple, list)): + return param[idx] + else: + return param def resolve_nn_activation(act_name: str) -> torch.nn.Module: