Commit 3f4d485

working training
1 parent d46bc2e commit 3f4d485

File tree

4 files changed: +69 -82 lines changed

rsl_rl/modules/perceptive_actor_critic.py

Lines changed: 27 additions & 29 deletions
@@ -11,7 +11,7 @@

 from .actor_critic import ActorCritic

-from rsl_rl.networks import MLP, CNN, CNNConfig, EmpiricalNormalization
+from rsl_rl.networks import MLP, CNN, EmpiricalNormalization


 class PerceptiveActorCritic(ActorCritic):
@@ -24,8 +24,8 @@ def __init__(
         critic_obs_normalization: bool = False,
         actor_hidden_dims: list[int] = [256, 256, 256],
         critic_hidden_dims: list[int] = [256, 256, 256],
-        actor_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
-        critic_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
+        actor_cnn_config: dict[str, dict] | dict | None = None,
+        critic_cnn_config: dict[str, dict] | dict | None = None,
         activation: str = "elu",
         init_noise_std: float = 1.0,
         noise_std_type: str = "scalar",
@@ -45,10 +45,10 @@ def __init__(
         self.actor_obs_group_1d = []
         self.actor_obs_group_2d = []
         for obs_group in obs_groups["policy"]:
-            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
+            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                 self.actor_obs_group_2d.append(obs_group)
-                num_actor_in_channels.append(obs[obs_group].shape[0])
-            elif len(obs[obs_group].shape) == 1:
+                num_actor_in_channels.append(obs[obs_group].shape[1])
+            elif len(obs[obs_group].shape) == 2:  # B, C
                 self.actor_obs_group_1d.append(obs_group)
                 num_actor_obs += obs[obs_group].shape[-1]
             else:
@@ -59,36 +59,36 @@ def __init__(
         num_critic_obs = 0
         num_critic_in_channels = []
         for obs_group in obs_groups["critic"]:
-            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
+            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                 self.critic_obs_group_2d.append(obs_group)
-                num_critic_in_channels.append(obs[obs_group].shape[0])
-            else:
+                num_critic_in_channels.append(obs[obs_group].shape[1])
+            elif len(obs[obs_group].shape) == 2:  # B, C
                 self.critic_obs_group_1d.append(obs_group)
                 num_critic_obs += obs[obs_group].shape[-1]
+            else:
+                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")

         # actor cnn
         if self.actor_obs_group_2d:
             assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations."

             # check if multiple 2D actor observations are provided
-            if len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, CNNConfig):
+            if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()):
+                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations."
+            elif len(self.actor_obs_group_2d) > 1:
                 print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.")
                 actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d)))
-            elif len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, dict):
-                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations."
-            elif len(self.actor_obs_group_2d) == 1 and isinstance(actor_cnn_config, CNNConfig):
-                actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))
             else:
-                raise ValueError(f"Invalid combination of 2D actor observations {self.actor_obs_group_2d} and actor CNN config {actor_cnn_config}.")
+                actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))

-            self.actor_cnns = {}
+            self.actor_cnns = nn.ModuleDict()
             encoding_dims = []
             for idx, obs_group in enumerate(self.actor_obs_group_2d):
-                self.actor_cnns[obs_group] = CNN(actor_cnn_config[obs_group], num_actor_in_channels[idx], activation)
+                self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_config[obs_group])
                 print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")

-                # compute the encoding dimension
-                encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group]).shape[-1])
+                # compute the encoding dimension (cpu necessary as model not moved to device yet)
+                encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])

             encoding_dim = sum(encoding_dims)
         else:
@@ -111,24 +111,22 @@ def __init__(
             assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations."

             # check if multiple 2D critic observations are provided
-            if len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, CNNConfig):
+            if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()):
+                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations."
+            elif len(self.critic_obs_group_2d) > 1:
                 print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.")
                 critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)))
-            elif len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, dict):
-                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations."
-            elif len(self.critic_obs_group_2d) == 1 and isinstance(critic_cnn_config, CNNConfig):
-                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))
             else:
-                raise ValueError(f"Invalid combination of 2D critic observations {self.critic_obs_group_2d} and critic CNN config {critic_cnn_config}.")
+                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))

-            self.critic_cnns = {}
+            self.critic_cnns = nn.ModuleDict()
             encoding_dims = []
             for idx, obs_group in enumerate(self.critic_obs_group_2d):
-                self.critic_cnns[obs_group] = CNN(critic_cnn_config[obs_group], num_critic_in_channels[idx], activation)
+                self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group])
                 print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")

-                # compute the encoding dimension
-                encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group]).shape[-1])
+                # compute the encoding dimension (cpu necessary as model not moved to device yet)
+                encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])

             encoding_dim = sum(encoding_dims)
         else:
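
Note: with this commit, per-group CNN settings are plain keyword dicts instead of CNNConfig instances, and observation groups are dispatched by tensor rank: rank-4 tensors (B, C, H, W) are routed through a CNN encoder, rank-2 tensors (B, C) feed the MLP directly. A minimal construction sketch; the observation names, shapes, and the plain-dict observation container are illustrative assumptions, only the constructor keywords come from this diff:

import torch

from rsl_rl.modules import PerceptiveActorCritic

# illustrative observations: rank-2 proprioception and rank-4 depth images
obs = {
    "proprio": torch.zeros(8, 48),       # (B, C) -> concatenated into the MLP input
    "depth": torch.zeros(8, 1, 64, 64),  # (B, C, H, W) -> encoded by a CNN
}
obs_groups = {"policy": ["proprio", "depth"], "critic": ["proprio", "depth"]}

# a single plain dict; with one 2D group it is wrapped as {"depth": cnn_config}
cnn_config = {"out_channels": [16, 32], "kernel_size": (3, 3), "stride": 2}

policy = PerceptiveActorCritic(
    obs,
    obs_groups,
    12,  # num_actions, passed positionally as in OnPolicyRunner
    actor_cnn_config=cnn_config,
    critic_cnn_config=cnn_config,
)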

rsl_rl/networks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,5 +7,5 @@

 from .memory import Memory
 from .mlp import MLP
-from .cnn import CNN, CNNConfig
+from .cnn import CNN
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization

rsl_rl/networks/cnn.py

Lines changed: 39 additions & 50 deletions
@@ -6,25 +6,13 @@
 from __future__ import annotations

 import torch
-from dataclasses import MISSING, dataclass
 from torch import nn as nn

 from rsl_rl.utils import resolve_nn_activation


-@dataclass
-class CNNConfig:
-    out_channels: list[int] = MISSING
-    kernel_size: list[tuple[int, int]] | tuple[int, int] = MISSING
-    stride: list[int] | int = 1
-    flatten: bool = True
-    avg_pool: tuple[int, int] | None = None
-    batchnorm: bool | list[bool] = False
-    max_pool: bool | list[bool] = False
-
-
-class CNN(nn.Module):
-    def __init__(self, cfg: CNNConfig, in_channels: int, activation: str):
+class CNN(nn.Sequential):
+    def __init__(self, in_channels: int, activation: str, out_channels: list[int], kernel_size: list[tuple[int, int]] | tuple[int, int], stride: list[int] | int = 1, flatten: bool = True, avg_pool: tuple[int, int] | None = None, batchnorm: bool | list[bool] = False, max_pool: bool | list[bool] = False):
         """
         Convolutional Neural Network model.

@@ -33,62 +21,63 @@ def __init__(self, cfg: CNNConfig, in_channels: int, activation: str):
         """
         super().__init__()

-        if isinstance(cfg.batchnorm, bool):
-            cfg.batchnorm = [cfg.batchnorm] * len(cfg.out_channels)
-        if isinstance(cfg.max_pool, bool):
-            cfg.max_pool = [cfg.max_pool] * len(cfg.out_channels)
-        if isinstance(cfg.kernel_size, tuple):
-            cfg.kernel_size = [cfg.kernel_size] * len(cfg.out_channels)
-        if isinstance(cfg.stride, int):
-            cfg.stride = [cfg.stride] * len(cfg.out_channels)
+        if isinstance(batchnorm, bool):
+            batchnorm = [batchnorm] * len(out_channels)
+        if isinstance(max_pool, bool):
+            max_pool = [max_pool] * len(out_channels)
+        if isinstance(kernel_size, tuple):
+            kernel_size = [kernel_size] * len(out_channels)
+        if isinstance(stride, int):
+            stride = [stride] * len(out_channels)

         # get activation function
         activation_function = resolve_nn_activation(activation)

         # build model layers
-        modules = []
+        layers = []

-        for idx in range(len(cfg.out_channels)):
-            in_channels = cfg.in_channels if idx == 0 else cfg.out_channels[idx - 1]
-            modules.append(
+        for idx in range(len(out_channels)):
+            in_channels = in_channels if idx == 0 else out_channels[idx - 1]
+            layers.append(
                 nn.Conv2d(
                     in_channels=in_channels,
-                    out_channels=cfg.out_channels[idx],
-                    kernel_size=cfg.kernel_size[idx],
-                    stride=cfg.stride[idx],
+                    out_channels=out_channels[idx],
+                    kernel_size=kernel_size[idx],
+                    stride=stride[idx],
                 )
             )
-            if cfg.batchnorm[idx]:
-                modules.append(nn.BatchNorm2d(num_features=cfg.out_channels[idx]))
-            modules.append(activation_function)
-            if cfg.max_pool[idx]:
-                modules.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
-
-        self.architecture = nn.Sequential(*modules)
-
-        if cfg.avg_pool is not None:
-            self.avgpool = nn.AdaptiveAvgPool2d(cfg.avg_pool)
+            if batchnorm[idx]:
+                layers.append(nn.BatchNorm2d(num_features=out_channels[idx]))
+            layers.append(activation_function)
+            if max_pool[idx]:
+                layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+        # register the layers
+        for idx, layer in enumerate(layers):
+            self.add_module(f"{idx}", layer)
+
+        if avg_pool is not None:
+            self.avgpool = nn.AdaptiveAvgPool2d(avg_pool)
         else:
             self.avgpool = None

-        # initialize weights
-        self.init_weights(self.architecture)
-
         # save flatten config for forward function
-        self.flatten = cfg.flatten
+        self.flatten = flatten

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.architecture(x)
+        for layer in self:
+            x = layer(x)
         if self.flatten:
             x = x.flatten(start_dim=1)
         elif self.avgpool is not None:
             x = self.avgpool(x)
             x = x.flatten(start_dim=1)
         return x

-    @staticmethod
-    def init_weights(sequential):
-        [
-            torch.nn.init.xavier_uniform_(module.weight)
-            for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Conv2d))
-        ]
+    def init_weights(self, scales: float | tuple[float]):
+        """Initialize the weights of the CNN."""
+
+        # initialize the weights
+        for idx, module in enumerate(self):
+            if isinstance(module, nn.Conv2d):
+                nn.init.xavier_uniform_(module.weight)
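
Note: CNN is now an nn.Sequential subclass whose former CNNConfig fields are keyword arguments; a scalar stride, a single kernel_size tuple, and boolean batchnorm/max_pool flags are broadcast to every conv layer. A usage sketch under the new signature (the input shape is an illustrative assumption):

import torch

from rsl_rl.networks import CNN

# scalar stride and a single kernel size are broadcast to both conv layers
encoder = CNN(
    in_channels=1,
    activation="elu",
    out_channels=[16, 32],
    kernel_size=(3, 3),
    stride=2,
)

images = torch.zeros(8, 1, 64, 64)  # (B, C, H, W)
encoding = encoder(images)
print(encoding.shape)  # torch.Size([8, 7200]): 32 channels * 15 * 15, flattened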

rsl_rl/runners/on_policy_runner.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 import rsl_rl
 from rsl_rl.algorithms import PPO
 from rsl_rl.env import VecEnv
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config
+from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, PerceptiveActorCritic, resolve_rnd_config, resolve_symmetry_config
 from rsl_rl.utils import resolve_obs_groups, store_code_state


@@ -416,7 +416,7 @@ def _construct_algorithm(self, obs) -> PPO:

         # initialize the actor-critic
         actor_critic_class = eval(self.policy_cfg.pop("class_name"))
-        actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class(
+        actor_critic: ActorCritic | ActorCriticRecurrent | PerceptiveActorCritic = actor_critic_class(
             obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
         ).to(self.device)

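Note: _construct_algorithm pops "class_name" from the policy config, resolves it with eval (hence the new PerceptiveActorCritic import), and splats the remaining entries into the constructor. A hedged sketch of a matching policy config; all values are illustrative, only the keyword names appear in this commit:

# hypothetical training-config fragment; only "class_name" and the
# PerceptiveActorCritic keyword names come from this commit
policy_cfg = {
    "class_name": "PerceptiveActorCritic",
    "actor_hidden_dims": [256, 256, 256],
    "critic_hidden_dims": [256, 256, 256],
    # plain dicts, now that CNNConfig is removed
    "actor_cnn_config": {"out_channels": [16, 32], "kernel_size": (3, 3), "stride": 2},
    "critic_cnn_config": {"out_channels": [16, 32], "kernel_size": (3, 3), "stride": 2},
    "activation": "elu",
    "init_noise_std": 1.0,
    "noise_std_type": "scalar",
}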