Skip to content

Commit d57e146

Browse files
formatting 1
1 parent 8a0a959 commit d57e146

File tree

5 files changed

+15
-17
lines changed

5 files changed

+15
-17
lines changed

rsl_rl/modules/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66
"""Definitions for neural-network components for RL-agents."""
77

88
from .actor_critic import ActorCritic
9+
from .actor_critic_perceptive import ActorCriticPerceptive
910
from .actor_critic_recurrent import ActorCriticRecurrent
10-
from .perceptive_actor_critic import PerceptiveActorCritic
1111
from .rnd import RandomNetworkDistillation, resolve_rnd_config
1212
from .student_teacher import StudentTeacher
1313
from .student_teacher_recurrent import StudentTeacherRecurrent
1414
from .symmetry import resolve_symmetry_config
1515

1616
__all__ = [
1717
"ActorCritic",
18+
"ActorCriticPerceptive",
1819
"ActorCriticRecurrent",
19-
"PerceptiveActorCritic",
2020
"RandomNetworkDistillation",
2121
"StudentTeacher",
2222
"StudentTeacherRecurrent",

rsl_rl/modules/perceptive_actor_critic.py renamed to rsl_rl/modules/actor_critic_perceptive.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from .actor_critic import ActorCritic
1515

1616

17-
class PerceptiveActorCritic(ActorCritic):
18-
def __init__( # noqa: C901
17+
class ActorCriticPerceptive(ActorCritic):
18+
def __init__(
1919
self,
2020
obs,
2121
obs_groups,
@@ -34,7 +34,7 @@ def __init__( # noqa: C901
3434
if kwargs:
3535
print(
3636
"PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
37-
+ str([key for key in kwargs.keys()])
37+
+ str([key for key in kwargs])
3838
)
3939
nn.Module.__init__(self)
4040

@@ -74,9 +74,9 @@ def __init__( # noqa: C901
7474

7575
# check if multiple 2D actor observations are provided
7676
if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()):
77-
assert len(actor_cnn_config) == len(
78-
self.actor_obs_group_2d
79-
), "Number of CNN configs must match number of 2D actor observations."
77+
assert len(actor_cnn_config) == len(self.actor_obs_group_2d), (
78+
"Number of CNN configs must match number of 2D actor observations."
79+
)
8080
elif len(self.actor_obs_group_2d) > 1:
8181
print(
8282
"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups."
@@ -116,9 +116,9 @@ def __init__( # noqa: C901
116116

117117
# check if multiple 2D critic observations are provided
118118
if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()):
119-
assert len(critic_cnn_config) == len(
120-
self.critic_obs_group_2d
121-
), "Number of CNN configs must match number of 2D critic observations."
119+
assert len(critic_cnn_config) == len(self.critic_obs_group_2d), (
120+
"Number of CNN configs must match number of 2D critic observations."
121+
)
122122
elif len(self.critic_obs_group_2d) > 1:
123123
print(
124124
"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups."
@@ -171,7 +171,6 @@ def __init__( # noqa: C901
171171
Normal.set_default_validate_args(False)
172172

173173
def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
174-
175174
if self.actor_cnns is not None:
176175
# encode the 2D actor observations
177176
cnn_enc_list = []

rsl_rl/networks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
1212

1313
__all__ = [
14+
"CNN",
1415
"MLP",
1516
"EmpiricalDiscountedVariationNormalization",
1617
"EmpiricalNormalization",

rsl_rl/networks/cnn.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ def __init__(
2424
batchnorm: bool | list[bool] = False,
2525
max_pool: bool | list[bool] = False,
2626
):
27-
"""
28-
Convolutional Neural Network model.
27+
"""Convolutional Neural Network model.
2928
3029
.. note::
3130
Do not save config to allow for the model to be jit compiled.
@@ -87,7 +86,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8786

8887
def init_weights(self, scales: float | tuple[float]):
8988
"""Initialize the weights of the CNN."""
90-
9189
# initialize the weights
9290
for idx, module in enumerate(self):
9391
if isinstance(module, nn.Conv2d):

rsl_rl/runners/on_policy_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from rsl_rl.env import VecEnv
1919
from rsl_rl.modules import (
2020
ActorCritic,
21+
ActorCriticPerceptive,
2122
ActorCriticRecurrent,
22-
PerceptiveActorCritic,
2323
resolve_rnd_config,
2424
resolve_symmetry_config,
2525
)
@@ -420,7 +420,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
420420

421421
# Initialize the policy
422422
actor_critic_class = eval(self.policy_cfg.pop("class_name"))
423-
actor_critic: ActorCritic | ActorCriticRecurrent | PerceptiveActorCritic = actor_critic_class(
423+
actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive = actor_critic_class(
424424
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
425425
).to(self.device)
426426

0 commit comments

Comments
 (0)