
Commit af24fcf

fix all typing errors
1 parent 2d5240d commit af24fcf

10 files changed: +24 −24 lines


rsl_rl/modules/actor_critic.py

Lines changed: 5 additions & 5 deletions
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from torch.distributions import Normal
-from typing import NoReturn
+from typing import Any, NoReturn

 from rsl_rl.networks import MLP, EmpiricalNormalization

@@ -29,8 +29,8 @@ def __init__(
         activation: str = "elu",
         init_noise_std: float = 1.0,
         noise_std_type: str = "scalar",
-        state_dependent_std=False,
-        **kwargs,
+        state_dependent_std: bool = False,
+        **kwargs: dict[str, Any],
     ) -> None:
         if kwargs:
             print(
@@ -144,7 +144,7 @@ def _update_distribution(self, obs: TensorDict) -> None:
         # create distribution
         self.distribution = Normal(mean, std)

-    def act(self, obs: TensorDict, **kwargs) -> torch.Tensor:
+    def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
         obs = self.get_actor_obs(obs)
         obs = self.actor_obs_normalizer(obs)
         self._update_distribution(obs)
@@ -158,7 +158,7 @@ def act_inference(self, obs: TensorDict) -> torch.Tensor:
         else:
             return self.actor(obs)

-    def evaluate(self, obs: TensorDict, **kwargs) -> torch.Tensor:
+    def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
         obs = self.get_critic_obs(obs)
         obs = self.critic_obs_normalizer(obs)
         return self.critic(obs)
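
For reference, a minimal standalone sketch of the constructor pattern these annotations are applied to (the Policy class below is a hypothetical stand-in, not part of rsl_rl): typed parameter defaults plus an annotated catch-all **kwargs whose unexpected keys are reported, mirroring the diff above.

from typing import Any


class Policy:
    def __init__(
        self,
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        state_dependent_std: bool = False,
        **kwargs: dict[str, Any],
    ) -> None:
        if kwargs:
            # Warn about (and then ignore) unexpected keyword arguments, as in the diff above.
            print(f"Ignoring unexpected arguments: {list(kwargs)}")
        self.init_noise_std = init_noise_std
        self.noise_std_type = noise_std_type
        self.state_dependent_std = state_dependent_std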

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 import warnings
 from tensordict import TensorDict
 from torch.distributions import Normal
-from typing import NoReturn
+from typing import Any, NoReturn

 from rsl_rl.networks import MLP, EmpiricalNormalization, Memory

@@ -34,7 +34,7 @@ def __init__(
         rnn_type: str = "lstm",
         rnn_hidden_dim: int = 256,
         rnn_num_layers: int = 1,
-        **kwargs,
+        **kwargs: dict[str, Any],
     ) -> None:
         if "rnn_hidden_size" in kwargs:
             warnings.warn(

rsl_rl/modules/rnd.py

Lines changed: 5 additions & 5 deletions
@@ -8,7 +8,7 @@
 import torch
 import torch.nn as nn
 from tensordict import TensorDict
-from typing import NoReturn
+from typing import Any, NoReturn

 from rsl_rl.env import VecEnv
 from rsl_rl.networks import MLP, EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
@@ -139,7 +139,7 @@ def get_intrinsic_reward(self, obs: TensorDict) -> torch.Tensor:

         return intrinsic_reward

-    def forward(self, *args, **kwargs) -> NoReturn:
+    def forward(self, *args: Any, **kwargs: dict[str, Any]) -> NoReturn:
         raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")

     def train(self, mode: bool = True) -> RandomNetworkDistillation:
@@ -168,14 +168,14 @@ def update_normalization(self, obs: TensorDict) -> None:
     Different weight schedules.
     """

-    def _constant_weight_schedule(self, step: int, **kwargs) -> float:
+    def _constant_weight_schedule(self, step: int, **kwargs: dict[str, Any]) -> float:
         return self.initial_weight

-    def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs) -> float:
+    def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs: dict[str, Any]) -> float:
         return self.initial_weight if step < final_step else final_value

     def _linear_weight_schedule(
-        self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs
+        self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs: dict[str, Any]
     ) -> float:
         if step < initial_step:
             return self.initial_weight
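
The NoReturn annotation on forward marks a method that never returns normally because it always raises. A minimal standalone sketch of that pattern together with one of the typed weight schedules (the class name and the initial_weight value are illustrative only, not the repository's code):

from typing import Any, NoReturn


class RNDSketch:
    initial_weight: float = 1.0

    def forward(self, *args: Any, **kwargs: dict[str, Any]) -> NoReturn:
        # Always raises, so NoReturn is the accurate return annotation.
        raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")

    def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs: dict[str, Any]) -> float:
        # Constant weight until final_step, then switch to final_value.
        return self.initial_weight if step < final_step else final_value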

rsl_rl/modules/student_teacher.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from torch.distributions import Normal
-from typing import NoReturn
+from typing import Any, NoReturn

 from rsl_rl.networks import MLP, EmpiricalNormalization

@@ -29,7 +29,7 @@ def __init__(
         activation: str = "elu",
         init_noise_std: float = 0.1,
         noise_std_type: str = "scalar",
-        **kwargs,
+        **kwargs: dict[str, Any],
     ) -> None:
         if kwargs:
             print(

rsl_rl/modules/student_teacher_recurrent.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 import warnings
 from tensordict import TensorDict
 from torch.distributions import Normal
-from typing import NoReturn
+from typing import Any, NoReturn

 from rsl_rl.networks import MLP, EmpiricalNormalization, Memory

@@ -34,7 +34,7 @@ def __init__(
         rnn_hidden_dim: int = 256,
         rnn_num_layers: int = 1,
         teacher_recurrent: bool = False,
-        **kwargs,
+        **kwargs: dict[str, Any],
     ) -> None:
         if "rnn_hidden_size" in kwargs:
             warnings.warn(

rsl_rl/networks/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def init_weights(self, scales: float | tuple[float]) -> None:
             scales: Scale factor for the weights.
         """

-        def get_scale(idx) -> float:
+        def get_scale(idx: int) -> float:
             """Get the scale factor for the weights of the MLP.

             Args:
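
get_scale is a nested helper inside init_weights that resolves a per-layer scale by index. A plausible sketch of such a helper under the scales: float | tuple[float] signature (assumed behavior, not the repository's exact body):

def init_weights_sketch(layers: list, scales: float | tuple[float, ...]) -> None:
    def get_scale(idx: int) -> float:
        # Per-layer scale if a tuple was passed, otherwise the shared scalar.
        return scales[idx] if isinstance(scales, tuple) else scales

    for idx, layer in enumerate(layers):
        print(f"layer {idx} uses scale {get_scale(idx)}")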

rsl_rl/runners/on_policy_runner.py

Lines changed: 2 additions & 2 deletions
@@ -308,7 +308,7 @@ def save(self, path: str, infos: dict | None = None) -> None:
         if self.logger_type in ["neptune", "wandb"] and not self.disable_logs:
             self.writer.save_model(path, self.current_learning_iteration)

-    def load(self, path: str, load_optimizer: bool = True, map_location: str | None = None):
+    def load(self, path: str, load_optimizer: bool = True, map_location: str | None = None) -> dict:
         loaded_dict = torch.load(path, weights_only=False, map_location=map_location)
         # -- Load model
         resumed_training = self.alg.policy.load_state_dict(loaded_dict["model_state_dict"])
@@ -327,7 +327,7 @@ def load(self, path: str, load_optimizer: bool = True, map_location: str | None
         self.current_learning_iteration = loaded_dict["iter"]
         return loaded_dict["infos"]

-    def get_inference_policy(self, device: str | None = None):
+    def get_inference_policy(self, device: str | None = None) -> callable:
         self.eval_mode()  # switch to evaluation mode (dropout for example)
         if device is not None:
             self.alg.policy.to(device)
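
A hedged usage sketch of the two newly annotated methods (the checkpoint path and the obs tensor are assumptions; get_inference_policy is assumed to return a callable mapping observations to actions):

# runner: an already constructed OnPolicyRunner; obs: a TensorDict of observations
infos = runner.load("/path/to/model.pt", load_optimizer=False)  # returns the saved infos dict
policy = runner.get_inference_policy(device="cuda:0")           # returns a callable
actions = policy(obs)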

rsl_rl/utils/neptune_utils.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ class NeptuneLogger:
     def __init__(self, project: str, token: str) -> None:
         self.run = neptune.init_run(project=project, api_token=token)

-    def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg) -> None:
+    def store_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
         self.run["runner_cfg"] = runner_cfg
         self.run["policy_cfg"] = policy_cfg
         self.run["alg_cfg"] = alg_cfg
@@ -84,7 +84,7 @@ def add_scalar(
     def stop(self) -> None:
         self.neptune_logger.run.stop()

-    def log_config(self, env_cfg: dict, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
+    def log_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
         self.neptune_logger.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)

     def save_model(self, model_path: str, iter: int) -> None:

rsl_rl/utils/wandb_utils.py

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None:
             "Train/mean_episode_length/time": "Train/mean_episode_length_time",
         }

-    def store_config(self, env_cfg, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
+    def store_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
         wandb.config.update({"runner_cfg": runner_cfg})
         wandb.config.update({"policy_cfg": policy_cfg})
         wandb.config.update({"alg_cfg": alg_cfg})
@@ -74,7 +74,7 @@ def add_scalar(
     def stop(self) -> None:
         wandb.finish()

-    def log_config(self, env_cfg, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
+    def log_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
         self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)

     def save_model(self, model_path: str, iter: int) -> None:
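
The dict | object union on env_cfg reflects that the environment config may arrive either as a plain dictionary or as a config object such as a dataclass. A minimal sketch of both call styles (EnvCfg is a hypothetical stand-in, and writer is assumed to be an already constructed logger instance):

from dataclasses import dataclass


@dataclass
class EnvCfg:
    num_envs: int = 4096
    episode_length_s: float = 20.0


writer.log_config(env_cfg={"num_envs": 4096}, runner_cfg={}, alg_cfg={}, policy_cfg={})  # dict form
writer.log_config(env_cfg=EnvCfg(), runner_cfg={}, alg_cfg={}, policy_cfg={})            # object form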

ruff.toml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ select = [
     # ruff
     "RUF",
 ]
-ignore = ["B006", "B007", "B028"]
+ignore = ["B006", "B007", "B028", "ANN401"]
 per-file-ignores = {"*/__init__.py" = ["F401"]}

 [lint.isort]
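
ANN401 is ruff's flake8-annotations rule against annotating a parameter or return value directly with typing.Any; ignoring it is presumably what lets signatures such as the *args: Any in rnd.py above pass ruff check. A minimal illustration of code the rule would otherwise report:

from typing import Any


def forward(*args: Any) -> None:  # without the ignore, ANN401 would flag the Any annotation here
    print(args)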
