Skip to content

Commit 42e406d

Browse files
format actor_critic_perceptive
1 parent 064ef49 commit 42e406d

File tree

3 files changed

+111
-93
lines changed

3 files changed

+111
-93
lines changed

rsl_rl/modules/actor_critic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ def __init__(
4949
assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
5050
num_critic_obs += obs[obs_group].shape[-1]
5151

52-
self.state_dependent_std = state_dependent_std
53-
5452
# Actor
53+
self.state_dependent_std = state_dependent_std
5554
if self.state_dependent_std:
5655
self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
5756
else:
@@ -121,7 +120,7 @@ def action_std(self) -> torch.Tensor:
121120
def entropy(self) -> torch.Tensor:
122121
return self.distribution.entropy().sum(dim=-1)
123122

124-
def _update_distribution(self, obs: TensorDict) -> None:
123+
def _update_distribution(self, obs: torch.Tensor) -> None:
125124
if self.state_dependent_std:
126125
# Compute mean and standard deviation
127126
mean_and_std = self.actor(obs)

rsl_rl/modules/actor_critic_perceptive.py

Lines changed: 107 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
import torch
99
import torch.nn as nn
10+
from tensordict import TensorDict
1011
from torch.distributions import Normal
12+
from typing import Any
1113

1214
from rsl_rl.networks import CNN, MLP, EmpiricalNormalization
1315

@@ -17,19 +19,20 @@
1719
class ActorCriticPerceptive(ActorCritic):
1820
def __init__(
1921
self,
20-
obs,
21-
obs_groups,
22-
num_actions,
22+
obs: TensorDict,
23+
obs_groups: dict[str, list[str]],
24+
num_actions: int,
2325
actor_obs_normalization: bool = False,
2426
critic_obs_normalization: bool = False,
2527
actor_hidden_dims: list[int] = [256, 256, 256],
2628
critic_hidden_dims: list[int] = [256, 256, 256],
27-
actor_cnn_config: dict[str, dict] | dict | None = None,
28-
critic_cnn_config: dict[str, dict] | dict | None = None,
29+
actor_cnn_cfg: dict[str, dict] | dict | None = None,
30+
critic_cnn_cfg: dict[str, dict] | dict | None = None,
2931
activation: str = "elu",
3032
init_noise_std: float = 1.0,
3133
noise_std_type: str = "scalar",
32-
**kwargs,
34+
state_dependent_std: bool = False,
35+
**kwargs: dict[str, Any],
3336
) -> None:
3437
if kwargs:
3538
print(
@@ -38,195 +41,212 @@ def __init__(
3841
)
3942
nn.Module.__init__(self)
4043

41-
# get the observation dimensions
44+
# Get the observation dimensions
4245
self.obs_groups = obs_groups
4346
num_actor_obs = 0
4447
num_actor_in_channels = []
45-
self.actor_obs_group_1d = []
46-
self.actor_obs_group_2d = []
48+
self.actor_obs_groups_1d = []
49+
self.actor_obs_groups_2d = []
4750
for obs_group in obs_groups["policy"]:
4851
if len(obs[obs_group].shape) == 4: # B, C, H, W
49-
self.actor_obs_group_2d.append(obs_group)
52+
self.actor_obs_groups_2d.append(obs_group)
5053
num_actor_in_channels.append(obs[obs_group].shape[1])
5154
elif len(obs[obs_group].shape) == 2: # B, C
52-
self.actor_obs_group_1d.append(obs_group)
55+
self.actor_obs_groups_1d.append(obs_group)
5356
num_actor_obs += obs[obs_group].shape[-1]
5457
else:
5558
raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
56-
57-
self.critic_obs_group_1d = []
58-
self.critic_obs_group_2d = []
5959
num_critic_obs = 0
6060
num_critic_in_channels = []
61+
self.critic_obs_groups_1d = []
62+
self.critic_obs_groups_2d = []
6163
for obs_group in obs_groups["critic"]:
6264
if len(obs[obs_group].shape) == 4: # B, C, H, W
63-
self.critic_obs_group_2d.append(obs_group)
65+
self.critic_obs_groups_2d.append(obs_group)
6466
num_critic_in_channels.append(obs[obs_group].shape[1])
6567
elif len(obs[obs_group].shape) == 2: # B, C
66-
self.critic_obs_group_1d.append(obs_group)
68+
self.critic_obs_groups_1d.append(obs_group)
6769
num_critic_obs += obs[obs_group].shape[-1]
6870
else:
6971
raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
7072

71-
# actor cnn
72-
if self.actor_obs_group_2d:
73-
assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations."
73+
# Actor CNN
74+
if self.actor_obs_groups_2d:
75+
assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations."
7476

75-
# check if multiple 2D actor observations are provided
76-
if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()):
77-
assert len(actor_cnn_config) == len(self.actor_obs_group_2d), (
78-
"Number of CNN configs must match number of 2D actor observations."
77+
# Check if multiple 2D actor observations are provided
78+
if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()):
79+
assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), (
80+
"The number of CNN configurations must match the number of 2D actor observations."
7981
)
80-
elif len(self.actor_obs_group_2d) > 1:
82+
elif len(self.actor_obs_groups_2d) > 1:
8183
print(
82-
"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups."
84+
"Only one CNN configuration for multiple 2D actor observations given, using the same configuration "
85+
"for all groups."
8386
)
84-
actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d)))
87+
actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d)))
8588
else:
86-
actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))
89+
actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg]))
8790

91+
# Create CNNs for each 2D actor observation
8892
self.actor_cnns = nn.ModuleDict()
8993
encoding_dims = []
90-
for idx, obs_group in enumerate(self.actor_obs_group_2d):
91-
self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_config[obs_group])
94+
for idx, obs_group in enumerate(self.actor_obs_groups_2d):
95+
self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_cfg[obs_group])
9296
print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
9397

94-
# compute the encoding dimension (cpu necessary as model not moved to device yet)
98+
# Compute the encoding dimension (cpu necessary as model not moved to device yet)
9599
encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
96-
97100
encoding_dim = sum(encoding_dims)
98101
else:
99102
self.actor_cnns = None
100103
encoding_dim = 0
101104

102-
# actor mlp
103-
self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)
105+
# Actor MLP
106+
self.state_dependent_std = state_dependent_std
107+
if self.state_dependent_std:
108+
self.actor = MLP(num_actor_obs + encoding_dim, [2, num_actions], actor_hidden_dims, activation)
109+
else:
110+
self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)
111+
print(f"Actor MLP: {self.actor}")
104112

105-
# actor observation normalization (only for 1D actor observations)
113+
# Actor observation normalization (only for 1D actor observations)
106114
self.actor_obs_normalization = actor_obs_normalization
107115
if actor_obs_normalization:
108116
self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
109117
else:
110118
self.actor_obs_normalizer = torch.nn.Identity()
111-
print(f"Actor MLP: {self.actor}")
112119

113-
# critic cnn
114-
if self.critic_obs_group_2d:
115-
assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations."
120+
# Critic CNN
121+
if self.critic_obs_groups_2d:
122+
assert critic_cnn_cfg is not None, " A critic CNN configuration is required for 2D critic observations."
116123

117124
# check if multiple 2D critic observations are provided
118-
if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()):
119-
assert len(critic_cnn_config) == len(self.critic_obs_group_2d), (
120-
"Number of CNN configs must match number of 2D critic observations."
125+
if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()):
126+
assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), (
127+
"The number of CNN configurations must match the number of 2D critic observations."
121128
)
122-
elif len(self.critic_obs_group_2d) > 1:
129+
elif len(self.critic_obs_groups_2d) > 1:
123130
print(
124-
"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups."
125-
)
126-
critic_cnn_config = dict(
127-
zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d))
131+
"Only one CNN configuration for multiple 2D critic observations given, using the same configuration"
132+
" for all groups."
128133
)
134+
critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d)))
129135
else:
130-
critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))
136+
critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg]))
131137

138+
# Create CNNs for each 2D critic observation
132139
self.critic_cnns = nn.ModuleDict()
133140
encoding_dims = []
134-
for idx, obs_group in enumerate(self.critic_obs_group_2d):
135-
self.critic_cnns[obs_group] = CNN(
136-
num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group]
137-
)
141+
for idx, obs_group in enumerate(self.critic_obs_groups_2d):
142+
self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_cfg[obs_group])
138143
print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
139144

140-
# compute the encoding dimension (cpu necessary as model not moved to device yet)
145+
# Compute the encoding dimension (cpu necessary as model not moved to device yet)
141146
encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
142-
143147
encoding_dim = sum(encoding_dims)
144148
else:
145149
self.critic_cnns = None
146150
encoding_dim = 0
147151

148-
# critic mlp
152+
# Critic MLP
149153
self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation)
154+
print(f"Critic MLP: {self.critic}")
150155

151-
# critic observation normalization (only for 1D critic observations)
156+
# Critic observation normalization (only for 1D critic observations)
152157
self.critic_obs_normalization = critic_obs_normalization
153158
if critic_obs_normalization:
154159
self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
155160
else:
156161
self.critic_obs_normalizer = torch.nn.Identity()
157-
print(f"Critic MLP: {self.critic}")
158162

159163
# Action noise
160164
self.noise_std_type = noise_std_type
161-
if self.noise_std_type == "scalar":
162-
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
163-
elif self.noise_std_type == "log":
164-
self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
165+
if self.state_dependent_std:
166+
torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
167+
if self.noise_std_type == "scalar":
168+
torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
169+
elif self.noise_std_type == "log":
170+
torch.nn.init.constant_(
171+
self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7))
172+
)
173+
else:
174+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
165175
else:
166-
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
176+
if self.noise_std_type == "scalar":
177+
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
178+
elif self.noise_std_type == "log":
179+
self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
180+
else:
181+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
182+
183+
# Action distribution
184+
# Note: Populated in update_distribution
185+
self.distribution = None
167186

168-
# Action distribution (populated in update_distribution)
169-
self.distribution: Normal = None
170-
# disable args validation for speedup
187+
# Disable args validation for speedup
171188
Normal.set_default_validate_args(False)
172189

173-
def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
190+
def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
174191
if self.actor_cnns is not None:
175-
# encode the 2D actor observations
176-
cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
192+
# Encode the 2D actor observations
193+
cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
177194
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
178-
# update mlp obs
195+
# Concatenate to the MLP observations
179196
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
180197

181-
super().update_distribution(mlp_obs)
198+
super()._update_distribution(mlp_obs)
182199

183-
def act(self, obs, **kwargs):
200+
def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
184201
mlp_obs, cnn_obs = self.get_actor_obs(obs)
185202
mlp_obs = self.actor_obs_normalizer(mlp_obs)
186-
self.update_distribution(mlp_obs, cnn_obs)
203+
self._update_distribution(mlp_obs, cnn_obs)
187204
return self.distribution.sample()
188205

189-
def act_inference(self, obs):
206+
def act_inference(self, obs: TensorDict) -> torch.Tensor:
190207
mlp_obs, cnn_obs = self.get_actor_obs(obs)
191208
mlp_obs = self.actor_obs_normalizer(mlp_obs)
192209

193210
if self.actor_cnns is not None:
194-
# encode the 2D actor observations
195-
cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
211+
# Encode the 2D actor observations
212+
cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
196213
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
197-
# update mlp obs
214+
# Concatenate to the MLP observations
198215
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
199216

200-
return self.actor(mlp_obs)
217+
if self.state_dependent_std:
218+
return self.actor(obs)[..., 0, :]
219+
else:
220+
return self.actor(mlp_obs)
201221

202-
def evaluate(self, obs, **kwargs):
222+
def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
203223
mlp_obs, cnn_obs = self.get_critic_obs(obs)
204224
mlp_obs = self.critic_obs_normalizer(mlp_obs)
205225

206226
if self.critic_cnns is not None:
207-
# encode the 2D critic observations
208-
cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_group_2d]
227+
# Encode the 2D critic observations
228+
cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d]
209229
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
210-
# update mlp obs
230+
# Concatenate to the MLP observations
211231
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
212232

213233
return self.critic(mlp_obs)
214234

215-
def get_actor_obs(self, obs):
235+
def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
236+
obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d]
216237
obs_dict_2d = {}
217-
obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_group_1d]
218-
for obs_group in self.actor_obs_group_2d:
238+
for obs_group in self.actor_obs_groups_2d:
219239
obs_dict_2d[obs_group] = obs[obs_group]
220240
return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
221241

222-
def get_critic_obs(self, obs):
242+
def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
243+
obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d]
223244
obs_dict_2d = {}
224-
obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_group_1d]
225-
for obs_group in self.critic_obs_group_2d:
245+
for obs_group in self.critic_obs_groups_2d:
226246
obs_dict_2d[obs_group] = obs[obs_group]
227247
return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
228248

229-
def update_normalization(self, obs) -> None:
249+
def update_normalization(self, obs: TensorDict) -> None:
230250
if self.actor_obs_normalization:
231251
actor_obs, _ = self.get_actor_obs(obs)
232252
self.actor_obs_normalizer.update(actor_obs)

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ def __init__(
6161
assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
6262
num_critic_obs += obs[obs_group].shape[-1]
6363

64-
self.state_dependent_std = state_dependent_std
65-
6664
# Actor
65+
self.state_dependent_std = state_dependent_std
6766
self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
6867
if self.state_dependent_std:
6968
self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
@@ -138,7 +137,7 @@ def reset(self, dones: torch.Tensor | None = None) -> None:
138137
def forward(self) -> NoReturn:
139138
raise NotImplementedError
140139

141-
def _update_distribution(self, obs: TensorDict) -> None:
140+
def _update_distribution(self, obs: torch.Tensor) -> None:
142141
if self.state_dependent_std:
143142
# Compute mean and standard deviation
144143
mean_and_std = self.actor(obs)

0 commit comments

Comments
 (0)