@@ -8,7 +8,7 @@
 import torch
 import torch.nn as nn
 from tensordict import TensorDict
-from typing import NoReturn
+from typing import Any, NoReturn
 
 from rsl_rl.env import VecEnv
 from rsl_rl.networks import MLP, EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
@@ -139,7 +139,7 @@ def get_intrinsic_reward(self, obs: TensorDict) -> torch.Tensor:
 
         return intrinsic_reward
 
-    def forward(self, *args, **kwargs) -> NoReturn:
+    def forward(self, *args: Any, **kwargs: Any) -> NoReturn:
         raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")
 
     def train(self, mode: bool = True) -> RandomNetworkDistillation:
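The behavior is unchanged by the annotation update: calling the module directly is still a deliberate dead end, and `get_intrinsic_reward` remains the entry point. A minimal sketch of that disabled-forward pattern, assuming a toy stand-in class (the name `ToyRND`, the layer sizes, and the squared-error bonus below are illustrative, not the rsl_rl implementation):

```python
from typing import Any, NoReturn

import torch
import torch.nn as nn


class ToyRND(nn.Module):
    """Toy stand-in for an RND-style module whose forward() is disabled."""

    def __init__(self, num_obs: int, embedding_dim: int = 8) -> None:
        super().__init__()
        self.target = nn.Linear(num_obs, embedding_dim)      # frozen, randomly initialized network
        self.predictor = nn.Linear(num_obs, embedding_dim)   # network trained to imitate the target
        self.target.requires_grad_(False)

    def get_intrinsic_reward(self, obs: torch.Tensor) -> torch.Tensor:
        # Prediction error against the frozen target serves as the novelty bonus.
        return (self.predictor(obs) - self.target(obs)).square().mean(dim=-1)

    def forward(self, *args: Any, **kwargs: Any) -> NoReturn:
        raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")


rnd = ToyRND(num_obs=4)
print(rnd.get_intrinsic_reward(torch.randn(2, 4)).shape)  # torch.Size([2])
# rnd(torch.randn(2, 4)) would raise RuntimeError by design.
```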
@@ -168,14 +168,14 @@ def update_normalization(self, obs: TensorDict) -> None:
     Different weight schedules.
     """
 
-    def _constant_weight_schedule(self, step: int, **kwargs) -> float:
+    def _constant_weight_schedule(self, step: int, **kwargs: Any) -> float:
         return self.initial_weight
 
-    def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs) -> float:
+    def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs: Any) -> float:
         return self.initial_weight if step < final_step else final_value
 
     def _linear_weight_schedule(
-        self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs
+        self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs: Any
     ) -> float:
         if step < initial_step:
             return self.initial_weight
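For context, the three schedules map a training step to an intrinsic-reward weight: constant, a step change at `final_step`, and a linear ramp between `initial_step` and `final_step`. The hunk ends before the linear schedule's interpolation branch, so the continuation below is an assumption; the sketch also uses free functions with an explicit `initial_weight` argument in place of `self.initial_weight`:

```python
from typing import Any


def constant_weight_schedule(step: int, initial_weight: float, **kwargs: Any) -> float:
    # Same weight at every step.
    return initial_weight


def step_weight_schedule(
    step: int, initial_weight: float, final_step: int, final_value: float, **kwargs: Any
) -> float:
    # Switch from the initial weight to the final value at final_step.
    return initial_weight if step < final_step else final_value


def linear_weight_schedule(
    step: int, initial_weight: float, initial_step: int, final_step: int, final_value: float, **kwargs: Any
) -> float:
    # Hold, then interpolate, then hold again. The interpolation branch is an
    # assumption; the diff above stops after the first branch.
    if step < initial_step:
        return initial_weight
    if step >= final_step:
        return final_value
    fraction = (step - initial_step) / (final_step - initial_step)
    return initial_weight + fraction * (final_value - initial_weight)


# Example: ramp the weight from 1.0 down to 0.1 between steps 100 and 200.
for s in (0, 100, 150, 200, 300):
    print(s, linear_weight_schedule(s, initial_weight=1.0, initial_step=100, final_step=200, final_value=0.1))
```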