Commit d46bc2e

add files for perceptive example

1 parent bc1c7c4 · commit d46bc2e

5 files changed: +339 −6 lines changed

5 files changed

+339
-6
lines changed

rsl_rl/modules/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 
 from .actor_critic import ActorCritic
 from .actor_critic_recurrent import ActorCriticRecurrent
+from .perceptive_actor_critic import PerceptiveActorCritic
 from .rnd import *
 from .student_teacher import StudentTeacher
 from .student_teacher_recurrent import StudentTeacherRecurrent
@@ -15,6 +16,7 @@
 __all__ = [
     "ActorCritic",
     "ActorCriticRecurrent",
+    "PerceptiveActorCritic",
     "StudentTeacher",
     "StudentTeacherRecurrent",
 ]
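
With this export in place, the new class is importable from the package root, e.g.:

from rsl_rl.modules import PerceptiveActorCritic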

rsl_rl/modules/actor_critic.py

Lines changed: 6 additions & 6 deletions

@@ -20,12 +20,12 @@ def __init__(
         obs,
         obs_groups,
         num_actions,
-        actor_obs_normalization=False,
-        critic_obs_normalization=False,
-        actor_hidden_dims=[256, 256, 256],
-        critic_hidden_dims=[256, 256, 256],
-        activation="elu",
-        init_noise_std=1.0,
+        actor_obs_normalization: bool = False,
+        critic_obs_normalization: bool = False,
+        actor_hidden_dims: list[int] = [256, 256, 256],
+        critic_hidden_dims: list[int] = [256, 256, 256],
+        activation: str = "elu",
+        init_noise_std: float = 1.0,
         noise_std_type: str = "scalar",
         **kwargs,
     ):
rsl_rl/modules/perceptive_actor_critic.py

Lines changed: 236 additions & 0 deletions (new file)

# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from torch.distributions import Normal

from .actor_critic import ActorCritic

from rsl_rl.networks import MLP, CNN, CNNConfig, EmpiricalNormalization


class PerceptiveActorCritic(ActorCritic):
    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        actor_obs_normalization: bool = False,
        critic_obs_normalization: bool = False,
        actor_hidden_dims: list[int] = [256, 256, 256],
        critic_hidden_dims: list[int] = [256, 256, 256],
        actor_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
        critic_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
        activation: str = "elu",
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        if kwargs:
            print(
                "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs.keys()))
            )
        nn.Module.__init__(self)

        # get the observation dimensions: 1D groups feed the MLP, 2D groups feed a CNN encoder
        self.obs_groups = obs_groups
        num_actor_obs = 0
        num_actor_in_channels = []
        self.actor_obs_group_1d = []
        self.actor_obs_group_2d = []
        for obs_group in obs_groups["policy"]:
            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
                self.actor_obs_group_2d.append(obs_group)
                num_actor_in_channels.append(obs[obs_group].shape[0])
            elif len(obs[obs_group].shape) == 1:
                self.actor_obs_group_1d.append(obs_group)
                num_actor_obs += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")

        self.critic_obs_group_1d = []
        self.critic_obs_group_2d = []
        num_critic_obs = 0
        num_critic_in_channels = []
        for obs_group in obs_groups["critic"]:
            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
                self.critic_obs_group_2d.append(obs_group)
                num_critic_in_channels.append(obs[obs_group].shape[0])
            else:
                self.critic_obs_group_1d.append(obs_group)
                num_critic_obs += obs[obs_group].shape[-1]

        # actor cnn
        if self.actor_obs_group_2d:
            assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations."

            # resolve the CNN config into one config per 2D actor observation group
            if len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, CNNConfig):
                print("Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.")
                actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d)))
            elif len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, dict):
                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), (
                    "Number of CNN configs must match number of 2D actor observations."
                )
            elif len(self.actor_obs_group_2d) == 1 and isinstance(actor_cnn_config, CNNConfig):
                actor_cnn_config = {self.actor_obs_group_2d[0]: actor_cnn_config}
            else:
                raise ValueError(
                    f"Invalid combination of 2D actor observations {self.actor_obs_group_2d} "
                    f"and actor CNN config {actor_cnn_config}."
                )

            # use a ModuleDict so the CNNs are registered as submodules
            self.actor_cnns = nn.ModuleDict()
            encoding_dims = []
            for idx, obs_group in enumerate(self.actor_obs_group_2d):
                self.actor_cnns[obs_group] = CNN(actor_cnn_config[obs_group], num_actor_in_channels[idx], activation)
                print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")

                # compute the encoding dimension with a dummy forward pass
                encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group]).shape[-1])

            encoding_dim = sum(encoding_dims)
        else:
            self.actor_cnns = None
            encoding_dim = 0

        # actor mlp
        self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)

        # actor observation normalization (only for 1D actor observations)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()
        print(f"Actor MLP: {self.actor}")

        # critic cnn
        if self.critic_obs_group_2d:
            assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations."

            # resolve the CNN config into one config per 2D critic observation group
            if len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, CNNConfig):
                print("Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.")
                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)))
            elif len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, dict):
                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), (
                    "Number of CNN configs must match number of 2D critic observations."
                )
            elif len(self.critic_obs_group_2d) == 1 and isinstance(critic_cnn_config, CNNConfig):
                critic_cnn_config = {self.critic_obs_group_2d[0]: critic_cnn_config}
            else:
                raise ValueError(
                    f"Invalid combination of 2D critic observations {self.critic_obs_group_2d} "
                    f"and critic CNN config {critic_cnn_config}."
                )

            # use a ModuleDict so the CNNs are registered as submodules
            self.critic_cnns = nn.ModuleDict()
            encoding_dims = []
            for idx, obs_group in enumerate(self.critic_obs_group_2d):
                self.critic_cnns[obs_group] = CNN(critic_cnn_config[obs_group], num_critic_in_channels[idx], activation)
                print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")

                # compute the encoding dimension with a dummy forward pass
                encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group]).shape[-1])

            encoding_dim = sum(encoding_dims)
        else:
            self.critic_cnns = None
            encoding_dim = 0

        # critic mlp
        self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation)

        # critic observation normalization (only for 1D critic observations)
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()
        print(f"Critic MLP: {self.critic}")

        # action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # action distribution (populated in update_distribution)
        self.distribution: Normal = None
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
        if self.actor_cnns is not None:
            # encode the 2D actor observations
            cnn_enc_list = []
            for obs_group in self.actor_obs_group_2d:
                cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # append the encodings to the mlp obs
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        super().update_distribution(mlp_obs)

    def act(self, obs, **kwargs):
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        self.update_distribution(mlp_obs, cnn_obs)
        return self.distribution.sample()

    def act_inference(self, obs):
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)

        if self.actor_cnns is not None:
            # encode the 2D actor observations
            cnn_enc_list = []
            for obs_group in self.actor_obs_group_2d:
                cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # append the encodings to the mlp obs
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        return self.actor(mlp_obs)

    def evaluate(self, obs, **kwargs):
        mlp_obs, cnn_obs = self.get_critic_obs(obs)
        mlp_obs = self.critic_obs_normalizer(mlp_obs)

        if self.critic_cnns is not None:
            # encode the 2D critic observations
            cnn_enc_list = []
            for obs_group in self.critic_obs_group_2d:
                cnn_enc_list.append(self.critic_cnns[obs_group](cnn_obs[obs_group]))
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # append the encodings to the mlp obs
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        return self.critic(mlp_obs)

    def get_actor_obs(self, obs):
        obs_list_1d = []
        obs_dict_2d = {}
        for obs_group in self.actor_obs_group_1d:
            obs_list_1d.append(obs[obs_group])
        for obs_group in self.actor_obs_group_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def get_critic_obs(self, obs):
        obs_list_1d = []
        obs_dict_2d = {}
        for obs_group in self.critic_obs_group_1d:
            obs_list_1d.append(obs[obs_group])
        for obs_group in self.critic_obs_group_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def update_normalization(self, obs):
        if self.actor_obs_normalization:
            actor_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs, _ = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)
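
To illustrate the expected inputs, here is a minimal, hypothetical sketch; the group names, shapes, and config values are illustrative only and not part of this commit. Note that the constructor as committed treats a group as 2D when its tensor has two dimensions, which the FIXME above flags as likely needing to be three (i.e. (C, H, W) images):

import torch
from rsl_rl.networks import CNNConfig

# hypothetical observation layout: 1D groups are concatenated into the MLP input,
# 2D groups are routed through a per-group CNN encoder
obs = {
    "proprio": torch.zeros(48),         # len(shape) == 1 -> MLP input
    "height_map": torch.zeros(1, 187),  # len(shape) == 2 -> CNN input (see FIXME)
}
obs_groups = {"policy": ["proprio", "height_map"], "critic": ["proprio", "height_map"]}

# one CNNConfig per 2D group; a single CNNConfig would be reused for all 2D groups
actor_cnn_config = {"height_map": CNNConfig(out_channels=[16, 32], kernel_size=(3, 3), stride=2)}

# with the FIXME resolved so image groups are (C, H, W), construction would look like:
# policy = PerceptiveActorCritic(obs, obs_groups, num_actions=12,
#                                actor_cnn_config=actor_cnn_config, critic_cnn_config=actor_cnn_config)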

rsl_rl/networks/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,4 +7,5 @@
 
 from .memory import Memory
 from .mlp import MLP
+from .cnn import CNN, CNNConfig
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization

rsl_rl/networks/cnn.py

Lines changed: 94 additions & 0 deletions (new file)

# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
from dataclasses import MISSING, dataclass
from torch import nn

from rsl_rl.utils import resolve_nn_activation


@dataclass
class CNNConfig:
    out_channels: list[int] = MISSING
    kernel_size: list[tuple[int, int]] | tuple[int, int] = MISSING
    stride: list[int] | int = 1
    flatten: bool = True
    avg_pool: tuple[int, int] | None = None
    batchnorm: bool | list[bool] = False
    max_pool: bool | list[bool] = False


class CNN(nn.Module):
    def __init__(self, cfg: CNNConfig, in_channels: int, activation: str):
        """Convolutional Neural Network model.

        .. note::
            Do not save config to allow for the model to be jit compiled.
        """
        super().__init__()

        # broadcast scalar settings to one entry per conv layer
        if isinstance(cfg.batchnorm, bool):
            cfg.batchnorm = [cfg.batchnorm] * len(cfg.out_channels)
        if isinstance(cfg.max_pool, bool):
            cfg.max_pool = [cfg.max_pool] * len(cfg.out_channels)
        if isinstance(cfg.kernel_size, tuple):
            cfg.kernel_size = [cfg.kernel_size] * len(cfg.out_channels)
        if isinstance(cfg.stride, int):
            cfg.stride = [cfg.stride] * len(cfg.out_channels)

        # get activation function
        activation_function = resolve_nn_activation(activation)

        # build model layers
        modules = []
        for idx in range(len(cfg.out_channels)):
            # the first layer reads the input channels, later layers chain the previous out_channels
            layer_in_channels = in_channels if idx == 0 else cfg.out_channels[idx - 1]
            modules.append(
                nn.Conv2d(
                    in_channels=layer_in_channels,
                    out_channels=cfg.out_channels[idx],
                    kernel_size=cfg.kernel_size[idx],
                    stride=cfg.stride[idx],
                )
            )
            if cfg.batchnorm[idx]:
                modules.append(nn.BatchNorm2d(num_features=cfg.out_channels[idx]))
            modules.append(activation_function)
            if cfg.max_pool[idx]:
                modules.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.architecture = nn.Sequential(*modules)

        if cfg.avg_pool is not None:
            self.avgpool = nn.AdaptiveAvgPool2d(cfg.avg_pool)
        else:
            self.avgpool = None

        # initialize weights
        self.init_weights(self.architecture)

        # save flatten config for forward function
        self.flatten = cfg.flatten

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.architecture(x)
        if self.flatten:
            x = x.flatten(start_dim=1)
        elif self.avgpool is not None:
            x = self.avgpool(x)
            x = x.flatten(start_dim=1)
        return x

    @staticmethod
    def init_weights(sequential):
        # xavier-initialize all conv layers
        for module in sequential:
            if isinstance(module, nn.Conv2d):
                torch.nn.init.xavier_uniform_(module.weight)
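
A minimal sketch of the new network in isolation (shapes and config values are illustrative): a scalar kernel_size and stride are broadcast to every conv layer, and flatten=True collapses each sample's feature map into a vector:

import torch
from rsl_rl.networks import CNN, CNNConfig

# two conv layers; the scalar kernel_size/stride apply to both
cfg = CNNConfig(out_channels=[16, 32], kernel_size=(3, 3), stride=2)
net = CNN(cfg, in_channels=1, activation="elu")

x = torch.zeros(4, 1, 32, 32)  # (batch, channels, height, width)
print(net(x).shape)            # torch.Size([4, 1568]): 32 channels * 7 * 7, flattened per sample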
