|
15 | 15 | class RolloutStorage: |
16 | 16 | class Transition: |
17 | 17 | def __init__(self) -> None: |
18 | | - self.observations: TensorDict = None # type: ignore |
19 | | - self.actions: torch.Tensor = None # type: ignore |
20 | | - self.privileged_actions: torch.Tensor = None # type: ignore |
21 | | - self.rewards: torch.Tensor = None # type: ignore |
22 | | - self.dones: torch.Tensor = None # type: ignore |
23 | | - self.values: torch.Tensor = None # type: ignore |
24 | | - self.actions_log_prob: torch.Tensor = None # type: ignore |
25 | | - self.action_mean: torch.Tensor = None # type: ignore |
26 | | - self.action_sigma: torch.Tensor = None # type: ignore |
27 | | - self.hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None) # type: ignore |
| 18 | + self.observations: TensorDict | None = None |
| 19 | + self.actions: torch.Tensor | None = None |
| 20 | + self.privileged_actions: torch.Tensor | None = None |
| 21 | + self.rewards: torch.Tensor | None = None |
| 22 | + self.dones: torch.Tensor | None = None |
| 23 | + self.values: torch.Tensor | None = None |
| 24 | + self.actions_log_prob: torch.Tensor | None = None |
| 25 | + self.action_mean: torch.Tensor | None = None |
| 26 | + self.action_sigma: torch.Tensor | None = None |
| 27 | + self.hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...] = (None, None) |
28 | 28 |
|
29 | 29 | def clear(self) -> None: |
30 | 30 | self.__init__() |
@@ -102,7 +102,7 @@ def add_transitions(self, transition: Transition) -> None: |
102 | 102 | # Increment the counter |
103 | 103 | self.step += 1 |
104 | 104 |
|
105 | | - def _save_hidden_states(self, hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None]) -> None: |
| 105 | + def _save_hidden_states(self, hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...]) -> None: |
106 | 106 | if hidden_states == (None, None): |
107 | 107 | return |
108 | 108 | # Make a tuple out of GRU hidden states to match the LSTM format |
|
0 commit comments