@@ -166,7 +166,7 @@ def act(
166166 self ,
167167 obs : TensorDict ,
168168 masks : torch .Tensor | None = None ,
169- hidden_states : torch .Tensor | tuple [torch .Tensor ] | None = None ,
169+ hidden_states : torch .Tensor | tuple [torch .Tensor , ... ] | None = None ,
170170 ) -> torch .Tensor :
171171 obs = self .get_actor_obs (obs )
172172 obs = self .actor_obs_normalizer (obs )
@@ -187,7 +187,7 @@ def evaluate(
187187 self ,
188188 obs : TensorDict ,
189189 masks : torch .Tensor | None = None ,
190- hidden_states : torch .Tensor | tuple [torch .Tensor ] | None = None ,
190+ hidden_states : torch .Tensor | tuple [torch .Tensor , ... ] | None = None ,
191191 ) -> torch .Tensor :
192192 obs = self .get_critic_obs (obs )
193193 obs = self .critic_obs_normalizer (obs )
@@ -207,7 +207,7 @@ def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
207207
208208 def get_hidden_states (
209209 self ,
210- ) -> tuple [torch .Tensor | tuple [torch .Tensor ] | None , torch .Tensor | tuple [torch .Tensor ] | None ]:
210+ ) -> tuple [torch .Tensor | tuple [torch .Tensor , ... ] | None , torch .Tensor | tuple [torch .Tensor , ... ] | None ]:
211211 return self .memory_a .hidden_states , self .memory_c .hidden_states
212212
213213 def update_normalization (self , obs : TensorDict ) -> None :
0 commit comments