diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index e25a6ba8f52..d554a20130d 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -724,9 +724,6 @@ def forward(self, tensordict: TensorDictBase): tensordict_shaped.get(key, default) for key, default in zip(self.in_keys, defaults) ) - batch, steps = value.shape[:2] - device = value.device - dtype = value.dtype # packed sequences do not help to get the accurate last hidden values # if splits is not None: # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True) @@ -737,8 +734,13 @@ def forward(self, tensordict: TensorDictBase): # When using the recurrent_mode=True option, the lstm can be called from # any intermediate state, hence zeroing should not be done. is_init_expand = expand_as_right(is_init, hidden0) - hidden0 = torch.where(is_init_expand, 0, hidden0) - hidden1 = torch.where(is_init_expand, 0, hidden1) + zeros = torch.zeros_like(hidden0) + hidden0 = torch.where(is_init_expand, zeros, hidden0) + hidden1 = torch.where(is_init_expand, zeros, hidden1) + + batch, steps = value.shape[:2] + device = value.device + dtype = value.dtype val, hidden0, hidden1 = self._lstm( value, batch, steps, device, dtype, hidden0, hidden1