1010import warnings
1111from tensordict import TensorDict
1212from torch .distributions import Normal
13+ from typing import NoReturn
1314
1415from rsl_rl .networks import MLP , EmpiricalNormalization , Memory
1516
@@ -34,7 +35,7 @@ def __init__(
3435 rnn_num_layers : int = 1 ,
3536 teacher_recurrent : bool = False ,
3637 ** kwargs ,
37- ):
38+ ) -> None :
3839 if "rnn_hidden_size" in kwargs :
3940 warnings .warn (
4041 "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
@@ -112,12 +113,12 @@ def reset(
112113 self ,
113114 dones : torch .Tensor | None = None ,
114115 hidden_states : tuple [torch .Tensor | tuple [torch .Tensor ] | None ] = (None , None ),
115- ):
116+ ) -> None :
116117 self .memory_s .reset (dones , hidden_states [0 ])
117118 if self .teacher_recurrent :
118119 self .memory_t .reset (dones , hidden_states [1 ])
119120
120- def forward (self ):
121+ def forward (self ) -> NoReturn :
121122 raise NotImplementedError
122123
123124 @property
@@ -132,7 +133,7 @@ def action_std(self) -> torch.Tensor:
132133 def entropy (self ) -> torch .Tensor :
133134 return self .distribution .entropy ().sum (dim = - 1 )
134135
135- def _update_distribution (self , obs : TensorDict ):
136+ def _update_distribution (self , obs : TensorDict ) -> None :
136137 # compute mean
137138 mean = self .student (obs )
138139 # compute standard deviation
@@ -181,18 +182,18 @@ def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor] | None]:
181182 else :
182183 return self .memory_s .hidden_states , None
183184
184- def detach_hidden_states (self , dones : torch .Tensor | None = None ):
185+ def detach_hidden_states (self , dones : torch .Tensor | None = None ) -> None :
185186 self .memory_s .detach_hidden_states (dones )
186187 if self .teacher_recurrent :
187188 self .memory_t .detach_hidden_states (dones )
188189
189- def train (self , mode : bool = True ):
190+ def train (self , mode : bool = True ) -> None :
190191 super ().train (mode )
191192 # make sure teacher is in eval mode
192193 self .teacher .eval ()
193194 self .teacher_obs_normalizer .eval ()
194195
195- def update_normalization (self , obs : TensorDict ):
196+ def update_normalization (self , obs : TensorDict ) -> None :
196197 if self .student_obs_normalization :
197198 student_obs = self .get_student_obs (obs )
198199 self .student_obs_normalizer .update (student_obs )
0 commit comments