diff --git a/pyproject.toml b/pyproject.toml index 2f19c2c0..3745357d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ exclude = """(?x)( | sb3_contrib/ars/ars.py$ | sb3_contrib/common/recurrent/policies.py$ | sb3_contrib/common/recurrent/buffers.py$ + | sb3_contrib/common/torch_layers.py$ | tests/test_train_eval_mode.py$ )""" diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 5386db11..8cdc525a 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -228,6 +228,11 @@ def _get_samples( lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) + # See issue GH#284 + episode_starts = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]).astype( + self.episode_starts.dtype + ) + return RecurrentRolloutBufferSamples( # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)), @@ -237,7 +242,7 @@ def _get_samples( advantages=self.pad_and_flatten(self.advantages[batch_inds]), returns=self.pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + episode_starts=self.pad_and_flatten(episode_starts), mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), ) @@ -372,6 +377,10 @@ def _get_samples( observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} + episode_starts = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]).astype( + self.episode_starts.dtype + ) + return RecurrentDictRolloutBufferSamples( observations=observations, actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), @@ -380,6 +389,6 @@ def _get_samples( advantages=self.pad_and_flatten(self.advantages[batch_inds]), returns=self.pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + episode_starts=self.pad_and_flatten(episode_starts), mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), )