10 | 10 |
11 | 11 | from rsl_rl.utils import unpad_trajectories |
12 | 12 |
| 13 | +HiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None |
| 14 | +"""Type alias for the hidden state of RNNs (GRU/LSTM). |
| 15 | +
| 16 | +For GRUs, this is a single tensor, while for LSTMs it is a tuple of two tensors (hidden state and cell state).
| 17 | +""" |
| 18 | + |
13 | 19 |
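For reference, the alias matches PyTorch's own GRU/LSTM conventions. Below is a minimal sketch of the two shapes it covers, with assumed sizes (nothing here is specific to this PR):

```python
import torch
import torch.nn as nn

# Assumed sizes, for illustration only.
num_layers, num_envs, hidden_dim, input_size = 1, 4, 256, 32

gru = nn.GRU(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers)
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers)

x = torch.randn(1, num_envs, input_size)  # (seq_len, batch, features)

# GRU: the hidden state is a single tensor of shape (num_layers, batch, hidden_size).
_, gru_state = gru(x)
assert isinstance(gru_state, torch.Tensor)

# LSTM: the hidden state is a (hidden, cell) tuple of two such tensors.
_, lstm_state = lstm(x)
assert isinstance(lstm_state, tuple) and len(lstm_state) == 2
```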
14 | 20 | class Memory(nn.Module): |
15 | 21 | """Memory module for recurrent networks. |
16 | 22 |
17 | | -    This module is used to store the hidden states of the policy. It currently only supports GRU and LSTM.
| 23 | +    This module is used to store the hidden state of the policy. It currently supports GRU and LSTM.
18 | 24 |     """
19 | 25 |
20 | 26 |     def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:
21 | 27 |         super().__init__()
22 | 28 |         rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
23 | 29 |         self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers)
24 | | -        self.hidden_states = None
| 30 | +        self.hidden_state = None
25 | 31 |
26 | 32 |     def forward(
27 | 33 |         self,
28 | 34 |         input: torch.Tensor,
29 | 35 |         masks: torch.Tensor | None = None,
30 | | -        hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
| 36 | +        hidden_state: HiddenState = None,
31 | 37 |     ) -> torch.Tensor:
32 | 38 |         batch_mode = masks is not None
33 | 39 |         if batch_mode:
34 | 40 |             # Batch mode needs saved hidden states
35 | | -            if hidden_states is None:
| 41 | +            if hidden_state is None:
36 | 42 |                 raise ValueError("Hidden states not passed to memory module during policy update")
37 | | -            out, _ = self.rnn(input, hidden_states)
| 43 | +            out, _ = self.rnn(input, hidden_state)
38 | 44 |             out = unpad_trajectories(out, masks)
39 | 45 |         else:
40 | | -            # Inference/distillation mode uses hidden states of last step
41 | | -            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
| 46 | +            # Inference/distillation mode uses hidden state of last step
| 47 | +            out, self.hidden_state = self.rnn(input.unsqueeze(0), self.hidden_state)
42 | 48 |         return out
43 | 49 |
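The two branches of `forward` correspond to the two ways the module is driven. A rough usage sketch with assumed shapes; in practice the masks and the saved hidden state come from rsl_rl's rollout storage:

```python
import torch

memory = Memory(input_size=32, hidden_dim=256, num_layers=1, type="lstm")

# Rollout/inference: one observation batch per step; the hidden state is
# carried across calls in self.hidden_state.
obs = torch.randn(4, 32)                     # (num_envs, features), sizes assumed
out = memory(obs)                            # input is unsqueezed to a length-1 sequence

# Policy update: padded trajectories plus masks, with the hidden state that
# was saved at rollout time passed back in explicitly.
saved_state = memory.hidden_state            # in practice restored from storage
padded = torch.randn(24, 4, 32)              # (max_traj_len, num_trajs, features)
masks = torch.ones(24, 4, dtype=torch.bool)  # all-ones here; real masks flag valid steps
out = memory(padded, masks=masks, hidden_state=saved_state)
```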
44 | | -    def reset(
45 | | -        self, dones: torch.Tensor | None = None, hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None
46 | | -    ) -> None:
47 | | -        if dones is None: # Reset hidden states
48 | | -            if hidden_states is None:
49 | | -                self.hidden_states = None
| 50 | +    def reset(self, dones: torch.Tensor | None = None, hidden_state: HiddenState = None) -> None:
| 51 | +        if dones is None: # Reset hidden state
| 52 | +            if hidden_state is None:
| 53 | +                self.hidden_state = None
50 | 54 |             else:
51 | | -                self.hidden_states = hidden_states
52 | | -        elif self.hidden_states is not None: # Reset hidden states of done environments
53 | | -            if hidden_states is None:
54 | | -                if isinstance(self.hidden_states, tuple): # Tuple in case of LSTM
55 | | -                    for hidden_state in self.hidden_states:
| 55 | +                self.hidden_state = hidden_state
| 56 | +        elif self.hidden_state is not None: # Reset hidden state of done environments
| 57 | +            if hidden_state is None:
| 58 | +                if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM
| 59 | +                    for hidden_state in self.hidden_state:
56 | 60 |                         hidden_state[..., dones == 1, :] = 0.0
57 | 61 |                 else:
58 | | -                    self.hidden_states[..., dones == 1, :] = 0.0
| 62 | +                    self.hidden_state[..., dones == 1, :] = 0.0
59 | 63 |             else:
60 | | -                NotImplementedError(
| 64 | +                raise NotImplementedError(
61 | | - "Resetting hidden states of done environments with custom hidden states is not implemented" |
| 65 | + "Resetting the hidden state of done environments with a custom hidden state is not implemented" |
62 | 66 | ) |
63 | 67 |
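A short sketch of the two `reset` call patterns, assuming the class above is importable and using made-up sizes:

```python
import torch

memory = Memory(input_size=32, hidden_dim=256)  # LSTM by default
memory(torch.randn(4, 32))                      # populate the hidden state

dones = torch.tensor([0, 1, 0, 1])  # say environments 1 and 3 just terminated
memory.reset(dones)                 # zero only their slice of the hidden state
memory.reset()                      # or drop the state entirely, for all environments
```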
64 | | -    def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
65 | | -        if self.hidden_states is not None:
66 | | -            if dones is None: # Detach all hidden states
67 | | -                if isinstance(self.hidden_states, tuple): # Tuple in case of LSTM
68 | | -                    self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states)
| 68 | +    def detach_hidden_state(self, dones: torch.Tensor | None = None) -> None:
| 69 | +        if self.hidden_state is not None:
| 70 | +            if dones is None: # Detach hidden state
| 71 | +                if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM
| 72 | +                    self.hidden_state = tuple(hidden_state.detach() for hidden_state in self.hidden_state)
69 | 73 |                 else:
70 | | -                    self.hidden_states = self.hidden_states.detach()
71 | | -            else: # Detach hidden states of done environments
72 | | -                if isinstance(self.hidden_states, tuple): # Tuple in case of LSTM
73 | | -                    for hidden_state in self.hidden_states:
| 74 | +                    self.hidden_state = self.hidden_state.detach()
| 75 | +            else: # Detach hidden state of done environments
| 76 | +                if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM
| 77 | +                    for hidden_state in self.hidden_state:
74 | 78 |                         hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach()
75 | 79 |                 else:
76 | | -                    self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach()
| 80 | +                    self.hidden_state[..., dones == 1, :] = self.hidden_state[..., dones == 1, :].detach()
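The renamed `detach_hidden_state` serves as the truncation point for backpropagation through time: detaching cuts the autograd graph so one learning iteration does not backpropagate into activations of the previous one. A sketch of plausible call sites, under the same assumed setup:

```python
import torch

memory = Memory(input_size=32, hidden_dim=256)
memory(torch.randn(4, 32))  # populate the hidden state

# Between learning iterations: detach everything.
memory.detach_hidden_state()

# During a rollout: detach only the entries of environments whose episode
# just ended, leaving the rest of the graph intact.
memory.detach_hidden_state(dones=torch.tensor([0, 1, 0, 1]))
```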