Skip to content

Commit b2f620f

Browse files
formatting 2
1 parent d57e146 commit b2f620f

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

rsl_rl/modules/actor_critic_perceptive.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
init_noise_std: float = 1.0,
3131
noise_std_type: str = "scalar",
3232
**kwargs,
33-
):
33+
) -> None:
3434
if kwargs:
3535
print(
3636
"PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
@@ -170,12 +170,10 @@ def __init__(
170170
# disable args validation for speedup
171171
Normal.set_default_validate_args(False)
172172

173-
def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
173+
def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
174174
if self.actor_cnns is not None:
175175
# encode the 2D actor observations
176-
cnn_enc_list = []
177-
for obs_group in self.actor_obs_group_2d:
178-
cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
176+
cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
179177
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
180178
# update mlp obs
181179
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
@@ -194,9 +192,7 @@ def act_inference(self, obs):
194192

195193
if self.actor_cnns is not None:
196194
# encode the 2D actor observations
197-
cnn_enc_list = []
198-
for obs_group in self.actor_obs_group_2d:
199-
cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
195+
cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
200196
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
201197
# update mlp obs
202198
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
@@ -209,34 +205,28 @@ def evaluate(self, obs, **kwargs):
209205

210206
if self.critic_cnns is not None:
211207
# encode the 2D critic observations
212-
cnn_enc_list = []
213-
for obs_group in self.critic_obs_group_2d:
214-
cnn_enc_list.append(self.critic_cnns[obs_group](cnn_obs[obs_group]))
208+
cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_group_2d]
215209
cnn_enc = torch.cat(cnn_enc_list, dim=-1)
216210
# update mlp obs
217211
mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
218212

219213
return self.critic(mlp_obs)
220214

221215
def get_actor_obs(self, obs):
222-
obs_list_1d = []
223216
obs_dict_2d = {}
224-
for obs_group in self.actor_obs_group_1d:
225-
obs_list_1d.append(obs[obs_group])
217+
obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_group_1d]
226218
for obs_group in self.actor_obs_group_2d:
227219
obs_dict_2d[obs_group] = obs[obs_group]
228220
return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
229221

230222
def get_critic_obs(self, obs):
231-
obs_list_1d = []
232223
obs_dict_2d = {}
233-
for obs_group in self.critic_obs_group_1d:
234-
obs_list_1d.append(obs[obs_group])
224+
obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_group_1d]
235225
for obs_group in self.critic_obs_group_2d:
236226
obs_dict_2d[obs_group] = obs[obs_group]
237227
return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
238228

239-
def update_normalization(self, obs):
229+
def update_normalization(self, obs) -> None:
240230
if self.actor_obs_normalization:
241231
actor_obs, _ = self.get_actor_obs(obs)
242232
self.actor_obs_normalizer.update(actor_obs)

rsl_rl/networks/cnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
avg_pool: tuple[int, int] | None = None,
2424
batchnorm: bool | list[bool] = False,
2525
max_pool: bool | list[bool] = False,
26-
):
26+
) -> None:
2727
"""Convolutional Neural Network model.
2828
2929
.. note::
@@ -84,7 +84,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8484
x = x.flatten(start_dim=1)
8585
return x
8686

87-
def init_weights(self, scales: float | tuple[float]):
87+
def init_weights(self, scales: float | tuple[float]) -> None:
8888
"""Initialize the weights of the CNN."""
8989
# initialize the weights
9090
for idx, module in enumerate(self):

0 commit comments

Comments
 (0)