Skip to content

Commit 8a0a959

Browse files
pascal-roth authored and ClemensSchwarke committed
formatter
1 parent f35a2a1 commit 8a0a959

File tree

4 files changed

+54
-25
lines changed

4 files changed

+54
-25
lines changed

rsl_rl/modules/perceptive_actor_critic.py

Lines changed: 34 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -9,13 +9,13 @@
99
import torch.nn as nn
1010
from torch.distributions import Normal
1111

12-
from .actor_critic import ActorCritic
12+
from rsl_rl.networks import CNN, MLP, EmpiricalNormalization
1313

14-
from rsl_rl.networks import MLP, CNN, EmpiricalNormalization
14+
from .actor_critic import ActorCritic
1515

1616

1717
class PerceptiveActorCritic(ActorCritic):
18-
def __init__(
18+
def __init__( # noqa: C901
1919
self,
2020
obs,
2121
obs_groups,
@@ -53,7 +53,7 @@ def __init__(
5353
num_actor_obs += obs[obs_group].shape[-1]
5454
else:
5555
raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
56-
56+
5757
self.critic_obs_group_1d = []
5858
self.critic_obs_group_2d = []
5959
num_critic_obs = 0
@@ -71,12 +71,16 @@ def __init__(
7171
# actor cnn
7272
if self.actor_obs_group_2d:
7373
assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations."
74-
74+
7575
# check if multiple 2D actor observations are provided
7676
if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()):
77-
assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations."
77+
assert len(actor_cnn_config) == len(
78+
self.actor_obs_group_2d
79+
), "Number of CNN configs must match number of 2D actor observations."
7880
elif len(self.actor_obs_group_2d) > 1:
79-
print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.")
81+
print(
82+
"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups."
83+
)
8084
actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d)))
8185
else:
8286
actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))
@@ -89,15 +93,15 @@ def __init__(
8993

9094
# compute the encoding dimension (cpu necessary as model not moved to device yet)
9195
encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
92-
96+
9397
encoding_dim = sum(encoding_dims)
9498
else:
9599
self.actor_cnns = None
96100
encoding_dim = 0
97101

98102
# actor mlp
99103
self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)
100-
104+
101105
# actor observation normalization (only for 1D actor observations)
102106
self.actor_obs_normalization = actor_obs_normalization
103107
if actor_obs_normalization:
@@ -109,33 +113,41 @@ def __init__(
109113
# critic cnn
110114
if self.critic_obs_group_2d:
111115
assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations."
112-
116+
113117
# check if multiple 2D critic observations are provided
114118
if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()):
115-
assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations."
119+
assert len(critic_cnn_config) == len(
120+
self.critic_obs_group_2d
121+
), "Number of CNN configs must match number of 2D critic observations."
116122
elif len(self.critic_obs_group_2d) > 1:
117-
print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.")
118-
critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)))
123+
print(
124+
"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups."
125+
)
126+
critic_cnn_config = dict(
127+
zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d))
128+
)
119129
else:
120130
critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))
121131

122132
self.critic_cnns = nn.ModuleDict()
123133
encoding_dims = []
124134
for idx, obs_group in enumerate(self.critic_obs_group_2d):
125-
self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group])
135+
self.critic_cnns[obs_group] = CNN(
136+
num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group]
137+
)
126138
print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
127139

128140
# compute the encoding dimension (cpu necessary as model not moved to device yet)
129141
encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
130-
142+
131143
encoding_dim = sum(encoding_dims)
132144
else:
133145
self.critic_cnns = None
134146
encoding_dim = 0
135147

136148
# critic mlp
137149
self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation)
138-
150+
139151
# critic observation normalization (only for 1D critic observations)
140152
self.critic_obs_normalization = critic_obs_normalization
141153
if critic_obs_normalization:
@@ -159,7 +171,7 @@ def __init__(
159171
Normal.set_default_validate_args(False)
160172

161173
def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
162-
174+
163175
if self.actor_cnns is not None:
164176
# encode the 2D actor observations
165177
cnn_enc_list = []
@@ -168,7 +180,7 @@ def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Te
168180
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
169181
# update mlp obs
170182
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
171-
183+
172184
super().update_distribution(mlp_obs)
173185

174186
def act(self, obs, **kwargs):
@@ -180,7 +192,7 @@ def act(self, obs, **kwargs):
180192
def act_inference(self, obs):
181193
mlp_obs, cnn_obs = self.get_actor_obs(obs)
182194
mlp_obs = self.actor_obs_normalizer(mlp_obs)
183-
195+
184196
if self.actor_cnns is not None:
185197
# encode the 2D actor observations
186198
cnn_enc_list = []
@@ -189,7 +201,7 @@ def act_inference(self, obs):
189201
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
190202
# update mlp obs
191203
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
192-
204+
193205
return self.actor(mlp_obs)
194206

195207
def evaluate(self, obs, **kwargs):
@@ -204,7 +216,7 @@ def evaluate(self, obs, **kwargs):
204216
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
205217
# update mlp obs
206218
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
207-
219+
208220
return self.critic(mlp_obs)
209221

210222
def get_actor_obs(self, obs):
@@ -231,4 +243,4 @@ def update_normalization(self, obs):
231243
self.actor_obs_normalizer.update(actor_obs)
232244
if self.critic_obs_normalization:
233245
critic_obs, _ = self.get_critic_obs(obs)
234-
self.critic_obs_normalizer.update(critic_obs)
246+
self.critic_obs_normalizer.update(critic_obs)

rsl_rl/networks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55

66
"""Definitions for components of modules."""
77

8+
from .cnn import CNN
89
from .memory import Memory
910
from .mlp import MLP
10-
from .cnn import CNN
1111
from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
1212

1313
__all__ = [

rsl_rl/networks/cnn.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,18 @@
1212

1313

1414
class CNN(nn.Sequential):
15-
def __init__(self, in_channels: int, activation: str, out_channels: list[int], kernel_size: list[tuple[int, int]] | tuple[int, int], stride: list[int] | int = 1, flatten: bool = True, avg_pool: tuple[int, int] | None = None, batchnorm: bool | list[bool] = False, max_pool: bool | list[bool] = False):
15+
def __init__(
16+
self,
17+
in_channels: int,
18+
activation: str,
19+
out_channels: list[int],
20+
kernel_size: list[tuple[int, int]] | tuple[int, int],
21+
stride: list[int] | int = 1,
22+
flatten: bool = True,
23+
avg_pool: tuple[int, int] | None = None,
24+
batchnorm: bool | list[bool] = False,
25+
max_pool: bool | list[bool] = False,
26+
):
1627
"""
1728
Convolutional Neural Network model.
1829

rsl_rl/runners/on_policy_runner.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,13 @@
1616
import rsl_rl
1717
from rsl_rl.algorithms import PPO
1818
from rsl_rl.env import VecEnv
19-
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, PerceptiveActorCritic, resolve_rnd_config, resolve_symmetry_config
19+
from rsl_rl.modules import (
20+
ActorCritic,
21+
ActorCriticRecurrent,
22+
PerceptiveActorCritic,
23+
resolve_rnd_config,
24+
resolve_symmetry_config,
25+
)
2026
from rsl_rl.utils import resolve_obs_groups, store_code_state
2127

2228

0 commit comments

Comments
 (0)