Skip to content

Commit 064ef49

Browse files
CNN docstrings
1 parent b2f620f commit 064ef49

File tree

3 files changed

+31
-14
lines changed

3 files changed

+31
-14
lines changed

rsl_rl/algorithms/ppo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from itertools import chain
1212
from tensordict import TensorDict
1313

14-
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
14+
from rsl_rl.modules import ActorCritic, ActorCriticPerceptive, ActorCriticRecurrent
1515
from rsl_rl.modules.rnd import RandomNetworkDistillation
1616
from rsl_rl.storage import RolloutStorage
1717
from rsl_rl.utils import string_to_callable
@@ -20,12 +20,12 @@
2020
class PPO:
2121
"""Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
2222

23-
policy: ActorCritic | ActorCriticRecurrent
23+
policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive
2424
"""The actor critic module."""
2525

2626
def __init__(
2727
self,
28-
policy: ActorCritic | ActorCriticRecurrent,
28+
policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive,
2929
num_learning_epochs: int = 5,
3030
num_mini_batches: int = 4,
3131
clip_param: float = 0.2,

rsl_rl/networks/cnn.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212

1313

1414
class CNN(nn.Sequential):
15+
"""Convolutional Neural Network (CNN).
16+
17+
The CNN network is a sequence of convolutional layers, optional batch normalization, activation functions, and
18+
optional max pooling. The final output can be flattened or pooled depending on the configuration.
19+
"""
20+
1521
def __init__(
1622
self,
1723
in_channels: int,
@@ -24,13 +30,25 @@ def __init__(
2430
batchnorm: bool | list[bool] = False,
2531
max_pool: bool | list[bool] = False,
2632
) -> None:
27-
"""Convolutional Neural Network model.
33+
"""Initialize the CNN.
34+
35+
Args:
36+
in_channels: Number of input channels.
37+
activation: Activation function to use.
38+
out_channels: List of output channels for each convolutional layer.
39+
kernel_size: List of kernel sizes for each convolutional layer or a single kernel size for all layers.
40+
stride: List of strides for each convolutional layer or a single stride for all layers.
41+
flatten: Whether to flatten the output tensor.
42+
avg_pool: If specified, applies an adaptive average pooling to the given output size after the convolutions.
43+
batchnorm: Whether to apply batch normalization after each convolutional layer.
44+
max_pool: Whether to apply max pooling after each convolutional layer.
2845
2946
.. note::
3047
Do not save config to allow for the model to be jit compiled.
3148
"""
3249
super().__init__()
3350

51+
# If parameters are not lists, convert them to lists
3452
if isinstance(batchnorm, bool):
3553
batchnorm = [batchnorm] * len(out_channels)
3654
if isinstance(max_pool, bool):
@@ -40,12 +58,11 @@ def __init__(
4058
if isinstance(stride, int):
4159
stride = [stride] * len(out_channels)
4260

43-
# get activation function
61+
# Resolve activation function
4462
activation_function = resolve_nn_activation(activation)
4563

46-
# build model layers
64+
# Create layers sequentially
4765
layers = []
48-
4966
for idx in range(len(out_channels)):
5067
in_channels = in_channels if idx == 0 else out_channels[idx - 1]
5168
layers.append(
@@ -62,16 +79,17 @@ def __init__(
6279
if max_pool[idx]:
6380
layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
6481

65-
# register the layers
82+
# Register the layers
6683
for idx, layer in enumerate(layers):
6784
self.add_module(f"{idx}", layer)
6885

86+
# Add avgpool if specified
6987
if avg_pool is not None:
7088
self.avgpool = nn.AdaptiveAvgPool2d(avg_pool)
7189
else:
7290
self.avgpool = None
7391

74-
# save flatten config for forward function
92+
# Save flatten flag for forward function
7593
self.flatten = flatten
7694

7795
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -84,9 +102,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
84102
x = x.flatten(start_dim=1)
85103
return x
86104

87-
def init_weights(self, scales: float | tuple[float]) -> None:
88-
"""Initialize the weights of the CNN."""
89-
# initialize the weights
105+
def init_weights(self) -> None:
106+
"""Initialize the weights of the CNN with Xavier initialization."""
90107
for idx, module in enumerate(self):
91108
if isinstance(module, nn.Conv2d):
92109
nn.init.xavier_uniform_(module.weight)

rsl_rl/networks/memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313

1414
class Memory(nn.Module):
15-
"""Memory module for recurrent networks.
15+
"""Memory network for recurrent architectures.
1616
17-
This module is used to store the hidden states of the policy. It currently only supports GRU and LSTM.
17+
This network is used to store the hidden states of the policy. It currently only supports GRU and LSTM.
1818
"""
1919

2020
def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:

0 commit comments

Comments
 (0)