Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions rsl_rl/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from itertools import chain
from tensordict import TensorDict

from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.modules import ActorCritic, ActorCriticPerceptive, ActorCriticRecurrent
from rsl_rl.modules.rnd import RandomNetworkDistillation
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import string_to_callable
Expand All @@ -20,12 +20,12 @@
class PPO:
"""Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""

policy: ActorCritic | ActorCriticRecurrent
policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive
"""The actor critic module."""

def __init__(
self,
policy: ActorCritic | ActorCriticRecurrent,
policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive,
num_learning_epochs: int = 5,
num_mini_batches: int = 4,
clip_param: float = 0.2,
Expand Down
2 changes: 2 additions & 0 deletions rsl_rl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Definitions for neural-network components for RL-agents."""

from .actor_critic import ActorCritic
from .actor_critic_perceptive import ActorCriticPerceptive
from .actor_critic_recurrent import ActorCriticRecurrent
from .rnd import RandomNetworkDistillation, resolve_rnd_config
from .student_teacher import StudentTeacher
Expand All @@ -14,6 +15,7 @@

__all__ = [
"ActorCritic",
"ActorCriticPerceptive",
"ActorCriticRecurrent",
"RandomNetworkDistillation",
"StudentTeacher",
Expand Down
5 changes: 2 additions & 3 deletions rsl_rl/modules/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def __init__(
assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
num_critic_obs += obs[obs_group].shape[-1]

self.state_dependent_std = state_dependent_std

# Actor
self.state_dependent_std = state_dependent_std
if self.state_dependent_std:
self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
else:
Expand Down Expand Up @@ -121,7 +120,7 @@ def action_std(self) -> torch.Tensor:
def entropy(self) -> torch.Tensor:
return self.distribution.entropy().sum(dim=-1)

def _update_distribution(self, obs: TensorDict) -> None:
def _update_distribution(self, obs: torch.Tensor) -> None:
if self.state_dependent_std:
# Compute mean and standard deviation
mean_and_std = self.actor(obs)
Expand Down
269 changes: 269 additions & 0 deletions rsl_rl/modules/actor_critic_perceptive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.distributions import Normal
from typing import Any

from rsl_rl.networks import CNN, MLP, EmpiricalNormalization

from .actor_critic import ActorCritic


class ActorCriticPerceptive(ActorCritic):
    """Actor-critic for mixed 1D and 2D observations.

    Extends :class:`ActorCritic` by encoding every 2D observation group (shape
    ``(B, C, H, W)``) with a dedicated CNN. The flattened CNN encodings are
    concatenated with the 1D observation groups (shape ``(B, C)``) before being
    fed to the actor and critic MLPs. Empirical observation normalization, when
    enabled, is applied to the 1D observations only.
    """

    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        num_actions: int,
        actor_obs_normalization: bool = False,
        critic_obs_normalization: bool = False,
        actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        actor_cnn_cfg: dict[str, dict] | dict | None = None,
        critic_cnn_cfg: dict[str, dict] | dict | None = None,
        activation: str = "elu",
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        state_dependent_std: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize the perceptive actor-critic.

        Args:
            obs: Sample observations used to infer the per-group dimensions.
            obs_groups: Mapping from "policy"/"critic" to lists of observation group names.
            num_actions: Dimension of the action space.
            actor_obs_normalization: Whether to normalize the 1D actor observations.
            critic_obs_normalization: Whether to normalize the 1D critic observations.
            actor_hidden_dims: Hidden layer sizes of the actor MLP.
            critic_hidden_dims: Hidden layer sizes of the critic MLP.
            actor_cnn_cfg: CNN configuration(s) for the 2D actor observations. Either a
                single configuration shared by all 2D groups or one per group (keyed by
                group name). Required if any actor observation group is 2D.
            critic_cnn_cfg: Same as ``actor_cnn_cfg``, for the critic.
            activation: Activation function used by the MLPs.
            init_noise_std: Initial standard deviation of the action noise.
            noise_std_type: Parameterization of the noise std, "scalar" or "log".
            state_dependent_std: If True, the actor outputs the std alongside the mean.

        Raises:
            ValueError: If an observation group is neither 2D (``B, C``) nor
                4D (``B, C, H, W``), or if a CNN output is not flattened, or if
                ``noise_std_type`` is unknown.
        """
        if kwargs:
            # Fix: the message previously referenced the wrong class name ("PerceptiveActorCritic").
            print(
                "ActorCriticPerceptive.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs.keys()))
            )
        # Skip ActorCritic.__init__ on purpose: it asserts purely-1D observations.
        nn.Module.__init__(self)

        # Partition the observation groups of each head into 1D and 2D groups.
        self.obs_groups = obs_groups
        (
            num_actor_obs_1d,
            self.actor_obs_groups_1d,
            self.actor_obs_groups_2d,
            actor_in_dims_2d,
            actor_in_channels_2d,
        ) = self._split_obs(obs, obs_groups["policy"])
        (
            num_critic_obs_1d,
            self.critic_obs_groups_1d,
            self.critic_obs_groups_2d,
            critic_in_dims_2d,
            critic_in_channels_2d,
        ) = self._split_obs(obs, obs_groups["critic"])

        # Actor CNNs (one per 2D observation group) and their total encoding size.
        self.actor_cnns, actor_encoding_dim = self._build_cnns(
            self.actor_obs_groups_2d, actor_in_dims_2d, actor_in_channels_2d, actor_cnn_cfg, "actor"
        )

        # Actor MLP: consumes the concatenated 1D observations and CNN encodings.
        self.state_dependent_std = state_dependent_std
        if self.state_dependent_std:
            # Output both the mean and the std per action: shape (..., 2, num_actions).
            self.actor = MLP(num_actor_obs_1d + actor_encoding_dim, [2, num_actions], actor_hidden_dims, activation)
        else:
            self.actor = MLP(num_actor_obs_1d + actor_encoding_dim, num_actions, actor_hidden_dims, activation)
        print(f"Actor MLP: {self.actor}")

        # Actor observation normalization (only for 1D actor observations).
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()

        # Critic CNNs (one per 2D observation group) and their total encoding size.
        self.critic_cnns, critic_encoding_dim = self._build_cnns(
            self.critic_obs_groups_2d, critic_in_dims_2d, critic_in_channels_2d, critic_cnn_cfg, "critic"
        )

        # Critic MLP: outputs a single value estimate.
        self.critic = MLP(num_critic_obs_1d + critic_encoding_dim, 1, critic_hidden_dims, activation)
        print(f"Critic MLP: {self.critic}")

        # Critic observation normalization (only for 1D critic observations).
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        # Action noise parameterization.
        self.noise_std_type = noise_std_type
        if self.state_dependent_std:
            # Zero the weights of the std head so the initial std is exactly the bias.
            torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
            if self.noise_std_type == "scalar":
                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
            elif self.noise_std_type == "log":
                torch.nn.init.constant_(
                    self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7))
                )
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        else:
            if self.noise_std_type == "scalar":
                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
            elif self.noise_std_type == "log":
                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution).
        self.distribution = None

        # Disable args validation for speedup.
        Normal.set_default_validate_args(False)

    @staticmethod
    def _split_obs(
        obs: TensorDict, group_names: list[str]
    ) -> tuple[int, list[str], list[str], list, list]:
        """Partition observation groups into 1D and 2D groups.

        Returns the total 1D observation size, the 1D group names, the 2D group
        names, the (H, W) dims of each 2D group, and the channel count of each
        2D group.
        """
        num_obs_1d = 0
        groups_1d: list[str] = []
        groups_2d: list[str] = []
        in_dims_2d = []
        in_channels_2d = []
        for group in group_names:
            shape = obs[group].shape
            if len(shape) == 4:  # (B, C, H, W)
                groups_2d.append(group)
                in_dims_2d.append(shape[2:4])
                in_channels_2d.append(shape[1])
            elif len(shape) == 2:  # (B, C)
                groups_1d.append(group)
                num_obs_1d += shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {group}: {shape}")
        return num_obs_1d, groups_1d, groups_2d, in_dims_2d, in_channels_2d

    @staticmethod
    def _build_cnns(
        groups_2d: list[str],
        in_dims_2d: list,
        in_channels_2d: list,
        cnn_cfg: dict[str, dict] | dict | None,
        role: str,
    ) -> tuple[nn.ModuleDict | None, int]:
        """Create one CNN encoder per 2D observation group.

        Args:
            groups_2d: Names of the 2D observation groups.
            in_dims_2d: (H, W) input dimensions per group.
            in_channels_2d: Input channel counts per group.
            cnn_cfg: A single shared CNN configuration or one per group (keyed by name).
            role: "actor" or "critic", used only in messages.

        Returns:
            The module dict of CNNs (or ``None`` if no 2D groups) and the total
            flattened encoding dimension.
        """
        if not groups_2d:
            return None, 0
        assert cnn_cfg is not None, f"A {role} CNN configuration is required for 2D {role} observations."

        # Resolve to one configuration per 2D observation group.
        if len(groups_2d) > 1 and all(isinstance(item, dict) for item in cnn_cfg.values()):
            assert len(cnn_cfg) == len(groups_2d), (
                f"The number of CNN configurations must match the number of 2D {role} observations."
            )
        elif len(groups_2d) > 1:
            print(
                f"Only one CNN configuration for multiple 2D {role} observations given, using the same "
                "configuration for all groups."
            )
            cnn_cfg = dict(zip(groups_2d, [cnn_cfg] * len(groups_2d)))
        else:
            cnn_cfg = dict(zip(groups_2d, [cnn_cfg]))

        cnns = nn.ModuleDict()
        encoding_dim = 0
        for idx, group in enumerate(groups_2d):
            cnns[group] = CNN(
                input_dim=in_dims_2d[idx],
                input_channels=in_channels_2d[idx],
                **cnn_cfg[group],
            )
            print(f"{role.capitalize()} CNN for {group}: {cnns[group]}")
            # Only flattened CNN outputs can be concatenated with the 1D observations.
            if cnns[group].output_channels is not None:
                raise ValueError(f"The output of the {role} CNN must be flattened before passing it to the MLP.")
            encoding_dim += int(cnns[group].output_dim)  # type: ignore
        return cnns, encoding_dim

    @staticmethod
    def _encode_and_concat(
        mlp_obs: torch.Tensor,
        cnn_obs: dict[str, torch.Tensor],
        cnns: nn.ModuleDict | None,
        groups_2d: list[str],
    ) -> torch.Tensor:
        """Append the CNN encodings of the 2D observations to the 1D observation vector."""
        if cnns is None:
            return mlp_obs
        encodings = [cnns[group](cnn_obs[group]) for group in groups_2d]
        return torch.cat([mlp_obs, *encodings], dim=-1)

    def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
        """Update the action distribution from the (normalized) 1D and raw 2D observations."""
        mlp_obs = self._encode_and_concat(mlp_obs, cnn_obs, self.actor_cnns, self.actor_obs_groups_2d)
        super()._update_distribution(mlp_obs)

    def act(self, obs: TensorDict, **kwargs: Any) -> torch.Tensor:
        """Sample actions from the current policy distribution."""
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        self._update_distribution(mlp_obs, cnn_obs)
        return self.distribution.sample()

    def act_inference(self, obs: TensorDict) -> torch.Tensor:
        """Return the deterministic (mean) action for inference."""
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        mlp_obs = self._encode_and_concat(mlp_obs, cnn_obs, self.actor_cnns, self.actor_obs_groups_2d)
        if self.state_dependent_std:
            # Fix: feed the encoded observation vector (not the raw obs TensorDict)
            # to the actor, then select the mean channel of the (..., 2, A) output.
            return self.actor(mlp_obs)[..., 0, :]
        else:
            return self.actor(mlp_obs)

    def evaluate(self, obs: TensorDict, **kwargs: Any) -> torch.Tensor:
        """Return the critic's value estimate."""
        mlp_obs, cnn_obs = self.get_critic_obs(obs)
        mlp_obs = self.critic_obs_normalizer(mlp_obs)
        mlp_obs = self._encode_and_concat(mlp_obs, cnn_obs, self.critic_cnns, self.critic_obs_groups_2d)
        return self.critic(mlp_obs)

    def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Split the actor observations into a concatenated 1D tensor and a dict of 2D tensors."""
        obs_1d = torch.cat([obs[group] for group in self.actor_obs_groups_1d], dim=-1)
        obs_2d = {group: obs[group] for group in self.actor_obs_groups_2d}
        return obs_1d, obs_2d

    def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Split the critic observations into a concatenated 1D tensor and a dict of 2D tensors."""
        obs_1d = torch.cat([obs[group] for group in self.critic_obs_groups_1d], dim=-1)
        obs_2d = {group: obs[group] for group in self.critic_obs_groups_2d}
        return obs_1d, obs_2d

    def update_normalization(self, obs: TensorDict) -> None:
        """Update the empirical normalizers with the current 1D observations."""
        if self.actor_obs_normalization:
            actor_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs, _ = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)
5 changes: 2 additions & 3 deletions rsl_rl/modules/actor_critic_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ def __init__(
assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
num_critic_obs += obs[obs_group].shape[-1]

self.state_dependent_std = state_dependent_std

# Actor
self.state_dependent_std = state_dependent_std
self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
if self.state_dependent_std:
self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
Expand Down Expand Up @@ -138,7 +137,7 @@ def reset(self, dones: torch.Tensor | None = None) -> None:
def forward(self) -> NoReturn:
raise NotImplementedError

def _update_distribution(self, obs: TensorDict) -> None:
def _update_distribution(self, obs: torch.Tensor) -> None:
if self.state_dependent_std:
# Compute mean and standard deviation
mean_and_std = self.actor(obs)
Expand Down
2 changes: 2 additions & 0 deletions rsl_rl/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

"""Definitions for components of modules."""

from .cnn import CNN
from .memory import HiddenState, Memory
from .mlp import MLP
from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization

__all__ = [
"CNN",
"MLP",
"EmpiricalDiscountedVariationNormalization",
"EmpiricalNormalization",
Expand Down
Loading