Commit 7e3e65d
fix docstring formatting
1 parent 56415cf

File tree: 12 files changed (+72 additions, -103 deletions)


README.md
Lines changed: 1 addition & 2 deletions

@@ -57,8 +57,7 @@ For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napol
 We use the following tools for maintaining code quality:

 - [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase.
-- [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter.
-- [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker.
+- [ruff](https://github.com/astral-sh/ruff): An extremely fast Python linter and code formatter, written in Rust.

 Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal:
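(The command itself is clipped from this view; per pre-commit's documentation, the standard invocation for running all hooks over the whole repository is `pre-commit run --all-files`.)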

rsl_rl/env/vec_env.py
Lines changed: 7 additions & 14 deletions

@@ -13,9 +13,8 @@
 class VecEnv(ABC):
     """Abstract class for a vectorized environment.

-    The vectorized environment is a collection of environments that are synchronized. This means that
-    the same type of action is applied to all environments and the same type of observation is returned from all
-    environments.
+    The vectorized environment is a collection of environments that are synchronized. This means that the same type of
+    action is applied to all environments and the same type of observation is returned from all environments.
     """

     num_envs: int
@@ -41,16 +40,12 @@ class VecEnv(ABC):
     cfg: dict | object
     """Configuration object."""

-    """
-    Operations.
-    """
-
     @abstractmethod
     def get_observations(self) -> TensorDict:
         """Return the current observations.

         Returns:
-            observations: Observations from the environment.
+            The observations from the environment.
         """
         raise NotImplementedError

@@ -62,13 +57,12 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T
             actions: Input actions to apply. Shape: (num_envs, num_actions)

         Returns:
-            observations: Observations from the environment.
-            rewards: Rewards from the environment. Shape: (num_envs,)
-            dones: Done flags from the environment. Shape: (num_envs,)
-            extras: Extra information from the environment.
+            observations: Observations from the environment.
+            rewards: Rewards from the environment. Shape: (num_envs,)
+            dones: Done flags from the environment. Shape: (num_envs,)
+            extras: Extra information from the environment.

         Observations:
-
             The observations TensorDict usually contains multiple observation groups. The `obs_groups`
             dictionary of the runner configuration specifies which observation groups are used for which
             purpose, i.e., it maps the available observation groups to observation sets. The observation sets
@@ -83,7 +77,6 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T
             `rsl_rl/utils/utils.py`.

         Extras:
-
             The extras dictionary includes metrics such as the episode reward, episode length, etc. The following
             dictionary keys are used by rsl_rl:
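For orientation, here is a minimal sketch of how a consumer drives this interface. It assumes a concrete `VecEnv` implementation `env` and a `policy` callable; the `"policy"` observation group and the rollout length are illustrative assumptions, not part of the class.

```python
# Sketch of a rollout loop against the VecEnv interface above.
# Assumptions: `env` is a concrete VecEnv subclass and `policy` maps a
# (num_envs, obs_dim) tensor to (num_envs, num_actions) actions; the
# "policy" observation group is hypothetical and depends on `obs_groups`.
import torch

obs = env.get_observations()                # TensorDict of observation groups
for _ in range(24):                         # arbitrary rollout length
    with torch.no_grad():
        actions = policy(obs["policy"])     # shape: (num_envs, num_actions)
    obs, rewards, dones, extras = env.step(actions)
    # rewards and dones have shape (num_envs,); extras carries episode
    # metrics such as episode reward and length.
```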

rsl_rl/modules/actor_critic.py
Lines changed: 4 additions & 4 deletions

@@ -192,12 +192,12 @@ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:

         Args:
             state_dict: State dictionary of the model.
-            strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this
-                module's state_dict() function.
+            strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+                :meth:`state_dict` function.

         Returns:
-            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
-                `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
+            Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+                :class:`OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
         """
         super().load_state_dict(state_dict, strict=strict)
         return True
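As a hedged illustration of how the returned flag might be consumed (the checkpoint file name and the `"model_state_dict"` key below are assumptions, not the format `OnPolicyRunner` actually writes):

```python
# Sketch: using the resume flag returned by load_state_dict.
# `policy` is an ActorCritic instance; "checkpoint.pt" and the
# "model_state_dict" key are hypothetical checkpoint details.
import torch

checkpoint = torch.load("checkpoint.pt", map_location="cpu")
resumed = policy.load_state_dict(checkpoint["model_state_dict"], strict=True)
if resumed:
    # A runner could use this flag to decide whether to also restore
    # further training state, e.g., for distillation.
    pass
```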

rsl_rl/modules/actor_critic_recurrent.py
Lines changed: 4 additions & 4 deletions

@@ -223,12 +223,12 @@ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:

         Args:
             state_dict: State dictionary of the model.
-            strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this
-                module's state_dict() function.
+            strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+                :meth:`state_dict` function.

         Returns:
-            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
-                `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
+            Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+                :class:`OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
         """
         super().load_state_dict(state_dict, strict=strict)
         return True

rsl_rl/modules/rnd.py
Lines changed: 18 additions & 25 deletions

@@ -41,40 +41,37 @@ def __init__(
           layer.
         - If :attr:`reward_normalization` is True, then the intrinsic reward is normalized using an Empirical Discounted
           Variation Normalization layer.
-
-        .. note::
-            If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of
-            states is used as the hidden dimension.
+        - If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of states
+          is used as the hidden dimension.

         Args:
             num_states: Number of states/inputs to the predictor and target networks.
             obs_groups: Dictionary of observation groups.
             num_outputs: Number of outputs (embedding size) of the predictor and target networks.
             predictor_hidden_dims: List of hidden dimensions of the predictor network.
             target_hidden_dims: List of hidden dimensions of the target network.
-            activation: Activation function. Defaults to "elu".
-            weight: Scaling factor of the intrinsic reward. Defaults to 0.0.
-            state_normalization: Whether to normalize the input state. Defaults to False.
-            reward_normalization: Whether to normalize the intrinsic reward. Defaults to False.
-            device: Device to use. Defaults to "cpu".
-            weight_schedule: The type of schedule to use for the RND weight parameter.
-                Defaults to None, in which case the weight parameter is constant.
+            activation: Activation function.
+            weight: Scaling factor of the intrinsic reward.
+            state_normalization: Whether to normalize the input state.
+            reward_normalization: Whether to normalize the intrinsic reward.
+            device: Device to use.
+            weight_schedule: Type of schedule to use for the RND weight parameter.
                 It is a dictionary with the following keys:

-                - "mode": The type of schedule to use for the RND weight parameter.
+                - "mode": Type of schedule to use for the RND weight parameter.
                     - "constant": Constant weight schedule.
                     - "step": Step weight schedule.
                     - "linear": Linear weight schedule.

                 For the "step" weight schedule, the following parameters are required:

-                - "final_step": The step at which the weight parameter is set to the final value.
-                - "final_value": The final value of the weight parameter.
+                - "final_step": Step at which the weight parameter is set to the final value.
+                - "final_value": Final value of the weight parameter.

                 For the "linear" weight schedule, the following parameters are required:
-                - "initial_step": The step at which the weight parameter is set to the initial value.
-                - "final_step": The step at which the weight parameter is set to the final value.
-                - "final_value": The final value of the weight parameter.
+                - "initial_step": Step at which the weight parameter is set to the initial value.
+                - "final_step": Step at which the weight parameter is set to the final value.
+                - "final_value": Final value of the weight parameter.
         """
         # Initialize parent class
         super().__init__()
@@ -165,10 +162,6 @@ def update_normalization(self, obs: TensorDict) -> None:
         rnd_state = self.get_rnd_state(obs)
         self.state_normalizer.update(rnd_state)

-    """
-    Different weight schedules.
-    """
-
     def _constant_weight_schedule(self, step: int, **kwargs: dict[str, Any]) -> float:
         return self.initial_weight

@@ -192,10 +185,10 @@ def resolve_rnd_config(alg_cfg: dict, obs: TensorDict, obs_groups: dict[str, lis
     """Resolve the RND configuration.

     Args:
-        alg_cfg: The algorithm configuration dictionary.
-        obs: The observation dictionary.
-        obs_groups: The observation groups dictionary.
-        env: The environment.
+        alg_cfg: Algorithm configuration dictionary.
+        obs: Observation dictionary.
+        obs_groups: Observation groups dictionary.
+        env: Environment object.

     Returns:
         The resolved algorithm configuration dictionary.
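The schedule keys documented in the `weight_schedule` docstring above can be assembled into configurations like the following sketch; only the key names come from the documentation, while the numeric values are illustrative assumptions.

```python
# Illustrative weight-schedule configurations per the docstring above.
# Key names follow the documentation; the values are made up.
step_schedule = {
    "mode": "step",
    "final_step": 5_000,    # step at which the weight switches to final_value
    "final_value": 0.0,
}

linear_schedule = {
    "mode": "linear",
    "initial_step": 1_000,  # weight starts interpolating at this step
    "final_step": 10_000,   # weight reaches final_value at this step
    "final_value": 0.0,     # e.g., anneal the intrinsic reward away
}
```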

rsl_rl/modules/student_teacher.py
Lines changed: 4 additions & 4 deletions

@@ -172,12 +172,12 @@ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:

         Args:
             state_dict: State dictionary of the model.
-            strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this
-                module's state_dict() function.
+            strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+                :meth:`state_dict` function.

         Returns:
-            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
-                `OnPolicyRunner` to determine how to load further parameters.
+            Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+                :class:`OnPolicyRunner` to determine how to load further parameters.
         """
         # Check if state_dict contains teacher and student or just teacher parameters
         if any("actor" in key for key in state_dict):  # Load parameters from rl training

rsl_rl/modules/student_teacher_recurrent.py
Lines changed: 4 additions & 4 deletions

@@ -203,12 +203,12 @@ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:

         Args:
             state_dict: State dictionary of the model.
-            strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this
-                module's state_dict() function.
+            strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+                :meth:`state_dict` function.

         Returns:
-            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
-                `OnPolicyRunner` to determine how to load further parameters.
+            Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+                :class:`OnPolicyRunner` to determine how to load further parameters.
         """
         # Check if state_dict contains teacher and student or just teacher parameters
         if any("actor" in key for key in state_dict):  # Load parameters from rl training

rsl_rl/modules/symmetry.py
Lines changed: 2 additions & 2 deletions

@@ -12,8 +12,8 @@ def resolve_symmetry_config(alg_cfg: dict, env: VecEnv) -> dict:
     """Resolve the symmetry configuration.

     Args:
-        alg_cfg: The algorithm configuration dictionary.
-        env: The environment.
+        alg_cfg: Algorithm configuration dictionary.
+        env: Environment object.

     Returns:
         The resolved algorithm configuration dictionary.

rsl_rl/networks/memory.py
Lines changed: 1 addition & 2 deletions

@@ -14,8 +14,7 @@
 class Memory(nn.Module):
     """Memory module for recurrent networks.

-    This module is used to store the hidden states of the policy.
-    Currently only supports GRU and LSTM.
+    This module is used to store the hidden states of the policy. It currently only supports GRU and LSTM.
     """

     def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:
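A small, hypothetical instantiation following the signature above; per the docstring, only GRU and LSTM are supported, and the import path is an assumption.

```python
# Hypothetical instantiation of the Memory module per the signature above.
from rsl_rl.networks import Memory  # import path assumed

gru_memory = Memory(input_size=64, hidden_dim=256, num_layers=1, type="gru")
lstm_memory = Memory(input_size=64)  # defaults: hidden_dim=256, num_layers=1, type="lstm"
```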

rsl_rl/networks/mlp.py
Lines changed: 8 additions & 14 deletions

@@ -15,17 +15,12 @@
 class MLP(nn.Sequential):
     """Multi-layer perceptron.

-    The MLP network is a sequence of linear layers and activation functions. The
-    last layer is a linear layer that outputs the desired dimension unless the
-    last activation function is specified.
+    The MLP network is a sequence of linear layers and activation functions. The last layer is a linear layer that
+    outputs the desired dimension unless the last activation function is specified.

     It provides additional conveniences:
-
-    - If the hidden dimensions have a value of ``-1``, the dimension is inferred
-      from the input dimension.
-    - If the output dimension is a tuple, the output is reshaped to the desired
-      shape.
-
+    - If the hidden dimensions have a value of ``-1``, the dimension is inferred from the input dimension.
+    - If the output dimension is a tuple, the output is reshaped to the desired shape.
     """

     def __init__(
@@ -41,11 +36,10 @@ def __init__(
         Args:
             input_dim: Dimension of the input.
             output_dim: Dimension of the output.
-            hidden_dims: Dimensions of the hidden layers. A value of ``-1`` indicates
-                that the dimension should be inferred from the input dimension.
-            activation: Activation function. Defaults to "elu".
-            last_activation: Activation function of the last layer. Defaults to None,
-                in which case the last layer is linear.
+            hidden_dims: Dimensions of the hidden layers. A value of ``-1`` indicates that the dimension should be
+                inferred from the input dimension.
+            activation: Activation function.
+            last_activation: Activation function of the last layer. None results in a linear last layer.
         """
         super().__init__()
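A short sketch of the two conveniences described in the docstring above; the import path and the reshape behavior on forward are taken on the docstring's word, not verified against the implementation.

```python
# Sketch of the MLP conveniences documented above; import path assumed.
import torch
from rsl_rl.networks import MLP

# A -1 hidden dimension is inferred from the input dimension (here 32).
net = MLP(input_dim=32, output_dim=(4, 3), hidden_dims=[-1, 128], activation="elu")

x = torch.randn(16, 32)
y = net(x)  # per the docstring, a tuple output_dim reshapes the output, e.g. to (16, 4, 3)
```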
