Commit 2d5240d

return type none
1 parent: 743c60f · commit: 2d5240d

15 files changed: +90 -81 lines

rsl_rl/algorithms/distillation.py

Lines changed: 5 additions & 5 deletions

@@ -30,7 +30,7 @@ def __init__(
         device: str = "cpu",
         # Distributed training parameters
         multi_gpu_cfg: dict | None = None,
-    ):
+    ) -> None:
         # device-related parameters
         self.device = device
         self.is_multi_gpu = multi_gpu_cfg is not None
@@ -79,7 +79,7 @@ def init_storage(
         num_transitions_per_env: int,
         obs: TensorDict,
         actions_shape: tuple[int],
-    ):
+    ) -> None:
         # create rollout storage
         self.storage = RolloutStorage(
             training_type,
@@ -100,7 +100,7 @@ def act(self, obs: TensorDict) -> torch.Tensor:
 
     def process_env_step(
         self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor]
-    ):
+    ) -> None:
         # update the normalizers
         self.policy.update_normalization(obs)
 
@@ -163,7 +163,7 @@ def update(self) -> dict[str, float]:
     Helper functions
     """
 
-    def broadcast_parameters(self):
+    def broadcast_parameters(self) -> None:
         """Broadcast model parameters to all GPUs."""
         # obtain the model parameters on current GPU
         model_params = [self.policy.state_dict()]
@@ -172,7 +172,7 @@ def broadcast_parameters(self):
         # load the model parameters on all GPUs from source GPU
         self.policy.load_state_dict(model_params[0])
 
-    def reduce_parameters(self):
+    def reduce_parameters(self) -> None:
         """Collect gradients from all GPUs and average them.
 
         This function is called after the backward pass to synchronize the gradients across all GPUs.

rsl_rl/algorithms/ppo.py

Lines changed: 6 additions & 6 deletions

@@ -46,7 +46,7 @@ def __init__(
         symmetry_cfg: dict | None = None,
         # Distributed training parameters
         multi_gpu_cfg: dict | None = None,
-    ):
+    ) -> None:
         # device-related parameters
         self.device = device
         self.is_multi_gpu = multi_gpu_cfg is not None
@@ -123,7 +123,7 @@ def init_storage(
         num_transitions_per_env: int,
         obs: TensorDict,
         actions_shape: tuple[int] | list[int],
-    ):
+    ) -> None:
         # create rollout storage
         self.storage = RolloutStorage(
             training_type,
@@ -149,7 +149,7 @@ def act(self, obs: TensorDict) -> torch.Tensor:
 
     def process_env_step(
         self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor]
-    ):
+    ) -> None:
         # update the normalizers
         self.policy.update_normalization(obs)
         if self.rnd:
@@ -178,7 +178,7 @@ def process_env_step(
         self.transition.clear()
         self.policy.reset(dones)
 
-    def compute_returns(self, obs: TensorDict):
+    def compute_returns(self, obs: TensorDict) -> None:
         # compute value for the last step
         last_values = self.policy.evaluate(obs).detach()
         self.storage.compute_returns(
@@ -428,7 +428,7 @@ def update(self) -> dict[str, float]:
     Helper functions
     """
 
-    def broadcast_parameters(self):
+    def broadcast_parameters(self) -> None:
         """Broadcast model parameters to all GPUs."""
         # obtain the model parameters on current GPU
         model_params = [self.policy.state_dict()]
@@ -441,7 +441,7 @@ def broadcast_parameters(self):
         if self.rnd:
             self.rnd.predictor.load_state_dict(model_params[1])
 
-    def reduce_parameters(self):
+    def reduce_parameters(self) -> None:
         """Collect gradients from all GPUs and average them.
 
         This function is called after the backward pass to synchronize the gradients across all GPUs.
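
Note: the `broadcast_parameters` and `reduce_parameters` helpers annotated above follow a standard multi-GPU synchronization pattern. A minimal standalone sketch of that pattern is shown below; it assumes an already-initialized `torch.distributed` process group and is illustrative only, not the repository's implementation.

```python
# Illustrative sketch (not rsl_rl code): parameter broadcast and gradient
# averaging with torch.distributed, matching the docstrings quoted above.
import torch
import torch.distributed as dist


def broadcast_parameters(model: torch.nn.Module, src: int = 0) -> None:
    """Broadcast model parameters from the source rank to all other ranks."""
    model_params = [model.state_dict()]
    dist.broadcast_object_list(model_params, src=src)
    model.load_state_dict(model_params[0])


def reduce_parameters(model: torch.nn.Module, world_size: int) -> None:
    """Average gradients across all ranks; call after the backward pass."""
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```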

rsl_rl/modules/actor_critic.py

Lines changed: 6 additions & 5 deletions

@@ -9,6 +9,7 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from torch.distributions import Normal
+from typing import NoReturn
 
 from rsl_rl.networks import MLP, EmpiricalNormalization
 
@@ -30,7 +31,7 @@ def __init__(
         noise_std_type: str = "scalar",
         state_dependent_std=False,
         **kwargs,
-    ):
+    ) -> None:
         if kwargs:
             print(
                 "ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs])
@@ -101,10 +102,10 @@ def reset(
         self,
         dones: torch.Tensor | None = None,
         hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None),
-    ):
+    ) -> None:
         pass
 
-    def forward(self):
+    def forward(self) -> NoReturn:
         raise NotImplementedError
 
     @property
@@ -119,7 +120,7 @@ def action_std(self) -> torch.Tensor:
     def entropy(self) -> torch.Tensor:
         return self.distribution.entropy().sum(dim=-1)
 
-    def _update_distribution(self, obs: TensorDict):
+    def _update_distribution(self, obs: TensorDict) -> None:
         if self.state_dependent_std:
             # compute mean and standard deviation
             mean_and_std = self.actor(obs)
@@ -173,7 +174,7 @@ def get_critic_obs(self, obs: TensorDict) -> torch.Tensor:
     def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
         return self.distribution.log_prob(actions).sum(dim=-1)
 
-    def update_normalization(self, obs: TensorDict):
+    def update_normalization(self, obs: TensorDict) -> None:
         if self.actor_obs_normalization:
             actor_obs = self.get_actor_obs(obs)
             self.actor_obs_normalizer.update(actor_obs)
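
Note: this file also introduces `NoReturn` alongside `None`, following standard `typing` semantics: `-> None` marks a method that completes without returning a value, while `-> NoReturn` marks one that never returns normally (here, `forward` always raises). A small standalone illustration (not repository code):

```python
# Standalone illustration of the two annotations used in this commit.
from typing import NoReturn


class Annotated:
    def update_counter(self) -> None:
        # Mutates state and returns nothing, hence the `None` return type.
        self.counter = getattr(self, "counter", 0) + 1

    def forward(self) -> NoReturn:
        # Unconditionally raises, so it never returns normally: `NoReturn`.
        raise NotImplementedError
```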

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 6 additions & 5 deletions

@@ -10,6 +10,7 @@
 import warnings
 from tensordict import TensorDict
 from torch.distributions import Normal
+from typing import NoReturn
 
 from rsl_rl.networks import MLP, EmpiricalNormalization, Memory
 
@@ -34,7 +35,7 @@ def __init__(
         rnn_hidden_dim: int = 256,
         rnn_num_layers: int = 1,
         **kwargs,
-    ):
+    ) -> None:
         if "rnn_hidden_size" in kwargs:
             warnings.warn(
                 "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
@@ -126,14 +127,14 @@ def action_std(self) -> torch.Tensor:
     def entropy(self) -> torch.Tensor:
         return self.distribution.entropy().sum(dim=-1)
 
-    def reset(self, dones: torch.Tensor | None = None):
+    def reset(self, dones: torch.Tensor | None = None) -> None:
         self.memory_a.reset(dones)
         self.memory_c.reset(dones)
 
-    def forward(self):
+    def forward(self) -> NoReturn:
         raise NotImplementedError
 
-    def _update_distribution(self, obs: TensorDict):
+    def _update_distribution(self, obs: TensorDict) -> None:
         if self.state_dependent_std:
             # compute mean and standard deviation
             mean_and_std = self.actor(obs)
@@ -205,7 +206,7 @@ def get_hidden_states(
     ) -> tuple[torch.Tensor | tuple[torch.Tensor] | None, torch.Tensor | tuple[torch.Tensor] | None]:
         return self.memory_a.hidden_states, self.memory_c.hidden_states
 
-    def update_normalization(self, obs: TensorDict):
+    def update_normalization(self, obs: TensorDict) -> None:
         if self.actor_obs_normalization:
             actor_obs = self.get_actor_obs(obs)
             self.actor_obs_normalizer.update(actor_obs)

rsl_rl/modules/rnd.py

Lines changed: 4 additions & 3 deletions

@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 from tensordict import TensorDict
+from typing import NoReturn
 
 from rsl_rl.env import VecEnv
 from rsl_rl.networks import MLP, EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
@@ -33,7 +34,7 @@ def __init__(
         reward_normalization: bool = False,
         device: str = "cpu",
         weight_schedule: dict | None = None,
-    ):
+    ) -> None:
         """Initialize the RND module.
 
         - If :attr:`state_normalization` is True, then the input state is normalized using an Empirical Normalization
@@ -138,7 +139,7 @@ def get_intrinsic_reward(self, obs: TensorDict) -> torch.Tensor:
 
         return intrinsic_reward
 
-    def forward(self, *args, **kwargs):
+    def forward(self, *args, **kwargs) -> NoReturn:
         raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")
 
     def train(self, mode: bool = True) -> RandomNetworkDistillation:
@@ -157,7 +158,7 @@ def get_rnd_state(self, obs: TensorDict) -> torch.Tensor:
         obs_list = [obs[obs_group] for obs_group in self.obs_groups["rnd_state"]]
         return torch.cat(obs_list, dim=-1)
 
-    def update_normalization(self, obs: TensorDict):
+    def update_normalization(self, obs: TensorDict) -> None:
         # Normalize the state
         if self.state_normalization:
             rnd_state = self.get_rnd_state(obs)

rsl_rl/modules/student_teacher.py

Lines changed: 8 additions & 7 deletions

@@ -9,6 +9,7 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from torch.distributions import Normal
+from typing import NoReturn
 
 from rsl_rl.networks import MLP, EmpiricalNormalization
 
@@ -29,7 +30,7 @@ def __init__(
         init_noise_std: float = 0.1,
         noise_std_type: str = "scalar",
         **kwargs,
-    ):
+    ) -> None:
         if kwargs:
             print(
                 "StudentTeacher.__init__ got unexpected arguments, which will be ignored: "
@@ -93,10 +94,10 @@ def reset(
         self,
         dones: torch.Tensor | None = None,
         hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None),
-    ):
+    ) -> None:
         pass
 
-    def forward(self):
+    def forward(self) -> NoReturn:
         raise NotImplementedError
 
     @property
@@ -111,7 +112,7 @@ def action_std(self) -> torch.Tensor:
     def entropy(self) -> torch.Tensor:
         return self.distribution.entropy().sum(dim=-1)
 
-    def _update_distribution(self, obs: TensorDict):
+    def _update_distribution(self, obs: TensorDict) -> None:
         # compute mean
         mean = self.student(obs)
         # compute standard deviation
@@ -152,16 +153,16 @@ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
     def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor] | None]:
         return None, None
 
-    def detach_hidden_states(self, dones: torch.Tensor | None = None):
+    def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
         pass
 
-    def train(self, mode: bool = True):
+    def train(self, mode: bool = True) -> None:
         super().train(mode)
         # make sure teacher is in eval mode
         self.teacher.eval()
         self.teacher_obs_normalizer.eval()
 
-    def update_normalization(self, obs: TensorDict):
+    def update_normalization(self, obs: TensorDict) -> None:
         if self.student_obs_normalization:
             student_obs = self.get_student_obs(obs)
             self.student_obs_normalizer.update(student_obs)

rsl_rl/modules/student_teacher_recurrent.py

Lines changed: 8 additions & 7 deletions

@@ -10,6 +10,7 @@
 import warnings
 from tensordict import TensorDict
 from torch.distributions import Normal
+from typing import NoReturn
 
 from rsl_rl.networks import MLP, EmpiricalNormalization, Memory
 
@@ -34,7 +35,7 @@ def __init__(
         rnn_num_layers: int = 1,
         teacher_recurrent: bool = False,
         **kwargs,
-    ):
+    ) -> None:
         if "rnn_hidden_size" in kwargs:
             warnings.warn(
                 "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
@@ -112,12 +113,12 @@ def reset(
         self,
         dones: torch.Tensor | None = None,
         hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None),
-    ):
+    ) -> None:
         self.memory_s.reset(dones, hidden_states[0])
         if self.teacher_recurrent:
             self.memory_t.reset(dones, hidden_states[1])
 
-    def forward(self):
+    def forward(self) -> NoReturn:
         raise NotImplementedError
 
     @property
@@ -132,7 +133,7 @@ def action_std(self) -> torch.Tensor:
     def entropy(self) -> torch.Tensor:
         return self.distribution.entropy().sum(dim=-1)
 
-    def _update_distribution(self, obs: TensorDict):
+    def _update_distribution(self, obs: TensorDict) -> None:
         # compute mean
         mean = self.student(obs)
         # compute standard deviation
@@ -181,18 +182,18 @@ def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor] | None]:
         else:
             return self.memory_s.hidden_states, None
 
-    def detach_hidden_states(self, dones: torch.Tensor | None = None):
+    def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
         self.memory_s.detach_hidden_states(dones)
         if self.teacher_recurrent:
             self.memory_t.detach_hidden_states(dones)
 
-    def train(self, mode: bool = True):
+    def train(self, mode: bool = True) -> None:
         super().train(mode)
         # make sure teacher is in eval mode
         self.teacher.eval()
         self.teacher_obs_normalizer.eval()
 
-    def update_normalization(self, obs: TensorDict):
+    def update_normalization(self, obs: TensorDict) -> None:
         if self.student_obs_normalization:
             student_obs = self.get_student_obs(obs)
             self.student_obs_normalizer.update(student_obs)

rsl_rl/networks/memory.py

Lines changed: 5 additions & 3 deletions

@@ -18,7 +18,7 @@ class Memory(nn.Module):
     Currently only supports GRU and LSTM.
     """
 
-    def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm"):
+    def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:
         super().__init__()
         # RNN
         rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
@@ -43,7 +43,9 @@ def forward(
         out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
         return out
 
-    def reset(self, dones: torch.Tensor | None = None, hidden_states: torch.Tensor | tuple[torch.Tensor] | None = None):
+    def reset(
+        self, dones: torch.Tensor | None = None, hidden_states: torch.Tensor | tuple[torch.Tensor] | None = None
+    ) -> None:
         if dones is None:  # reset hidden states
             if hidden_states is None:
                 self.hidden_states = None
@@ -61,7 +63,7 @@ def reset(self, dones: torch.Tensor | None = None, hidden_states: torch.Tensor |
                 "Resetting hidden states of done environments with custom hidden states is not implemented"
             )
 
-    def detach_hidden_states(self, dones: torch.Tensor | None = None):
+    def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
         if self.hidden_states is not None:
             if dones is None:  # detach all hidden states
                 if isinstance(self.hidden_states, tuple):  # tuple in case of LSTM

rsl_rl/networks/mlp.py

Lines changed: 2 additions & 2 deletions

@@ -35,7 +35,7 @@ def __init__(
         hidden_dims: tuple[int] | list[int],
         activation: str = "elu",
         last_activation: str | None = None,
-    ):
+    ) -> None:
         """Initialize the MLP.
 
         Args:
@@ -82,7 +82,7 @@ def __init__(
         for idx, layer in enumerate(layers):
             self.add_module(f"{idx}", layer)
 
-    def init_weights(self, scales: float | tuple[float]):
+    def init_weights(self, scales: float | tuple[float]) -> None:
         """Initialize the weights of the MLP.
 
         Args:
