Skip to content

Commit df23975

Browse files
Fix _get_episode_idx (#161)
* Fix last element of last
1 parent b2fab67 commit df23975

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

mushroom_rl/utils/episodes.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,21 @@ def _get_episode_idx(last, backend=None):
4141
if backend is None:
4242
backend = ArrayBackend.get_array_backend_from(last)
4343

44+
last = backend.copy(last)
45+
last[-1] = True
46+
4447
n_episodes = last.sum()
4548
last_idx = backend.nonzero(last).squeeze()
4649
first_steps = backend.from_list([last_idx[0] + 1])
47-
if hasattr(last, 'device'):
50+
if backend.get_backend_name() == 'torch':
4851
first_steps = first_steps.to(last.device)
4952
episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]])
5053
max_episode_steps = episode_steps.max()
5154

52-
start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1])
55+
start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if backend.get_backend_name() == 'torch' else None), last_idx[:-1] + 1])
5356
range_n_episodes = backend.arange(0, n_episodes, dtype=int)
5457
range_len = backend.arange(0, last.shape[0], dtype=int)
55-
if hasattr(last, 'device'):
58+
if backend.get_backend_name() == 'torch':
5659
range_n_episodes = range_n_episodes.to(last.device)
5760
range_len = range_len.to(last.device)
5861
row_idx = backend.repeat(range_n_episodes, episode_steps)

0 commit comments

Comments
 (0)