Skip to content

Commit cd277d4

Browse files
authored
Fixes bug in ActorCriticRecurrent hidden state reset (#50)
* fix bug in rnn hidden state reset
1 parent 6909a47 commit cd277d4

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,7 @@ def forward(self, input, masks=None, hidden_states=None):
9595

9696
def reset(self, dones=None):
9797
# When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
98+
if self.hidden_states is None:
99+
return
98100
for hidden_state in self.hidden_states:
99-
hidden_state[..., dones, :] = 0.0
101+
hidden_state[..., dones == 1] = 0.0

0 commit comments

Comments
 (0)