diff --git a/mushroom_rl/utils/episodes.py b/mushroom_rl/utils/episodes.py index 7bfd5021..f73417cf 100644 --- a/mushroom_rl/utils/episodes.py +++ b/mushroom_rl/utils/episodes.py @@ -41,18 +41,21 @@ def _get_episode_idx(last, backend=None): if backend is None: backend = ArrayBackend.get_array_backend_from(last) + last = backend.copy(last) + last[-1] = True + n_episodes = last.sum() last_idx = backend.nonzero(last).squeeze() first_steps = backend.from_list([last_idx[0] + 1]) - if hasattr(last, 'device'): + if backend.get_backend_name() == 'torch': first_steps = first_steps.to(last.device) episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]]) max_episode_steps = episode_steps.max() - start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1]) + start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if backend.get_backend_name() == 'torch' else None), last_idx[:-1] + 1]) range_n_episodes = backend.arange(0, n_episodes, dtype=int) range_len = backend.arange(0, last.shape[0], dtype=int) - if hasattr(last, 'device'): + if backend.get_backend_name() == 'torch': range_n_episodes = range_n_episodes.to(last.device) range_len = range_len.to(last.device) row_idx = backend.repeat(range_n_episodes, episode_steps)