@@ -231,7 +231,7 @@ def _get_samples(
231231 return RecurrentRolloutBufferSamples (
232232 # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
233233 observations = self .pad (self .observations [batch_inds ]).reshape ((padded_batch_size , * self .obs_shape )),
234- actions = self .pad (self .actions [batch_inds ]).reshape ((padded_batch_size ,) + self .actions .shape [1 :]),
234+ actions = self .pad (self .actions [batch_inds ]).reshape ((padded_batch_size , * self .actions .shape [1 :]) ),
235235 old_values = self .pad_and_flatten (self .values [batch_inds ]),
236236 old_log_prob = self .pad_and_flatten (self .log_probs [batch_inds ]),
237237 advantages = self .pad_and_flatten (self .advantages [batch_inds ]),
@@ -374,7 +374,7 @@ def _get_samples(
374374
375375 return RecurrentDictRolloutBufferSamples (
376376 observations = observations ,
377- actions = self .pad (self .actions [batch_inds ]).reshape ((padded_batch_size ,) + self .actions .shape [1 :]),
377+ actions = self .pad (self .actions [batch_inds ]).reshape ((padded_batch_size , * self .actions .shape [1 :]) ),
378378 old_values = self .pad_and_flatten (self .values [batch_inds ]),
379379 old_log_prob = self .pad_and_flatten (self .log_probs [batch_inds ]),
380380 advantages = self .pad_and_flatten (self .advantages [batch_inds ]),
0 commit comments