We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3295221 commit 445d142Copy full SHA for 445d142
mushroom_rl/core/dataset.py
@@ -477,7 +477,7 @@ def compute_J(self, gamma=1.):
477
r_ep = split_episodes(self.last, self.reward)
478
479
if len(r_ep.shape) == 1:
480
- r_ep = r_ep.unsqueeze(0)
+ r_ep = self._array_backend.expand_dims(r_ep, 0)
481
if self._dataset_info.backend == 'torch':
482
js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device)
483
else:
0 commit comments