@@ -41,18 +41,21 @@ def _get_episode_idx(last, backend=None):
4141 if backend is None :
4242 backend = ArrayBackend .get_array_backend_from (last )
4343
44+ last = backend .copy (last )
45+ last [- 1 ] = True
46+
4447 n_episodes = last .sum ()
4548 last_idx = backend .nonzero (last ).squeeze ()
4649 first_steps = backend .from_list ([last_idx [0 ] + 1 ])
47- if hasattr ( last , 'device' ) :
50+ if backend . get_backend_name () == 'torch' :
4851 first_steps = first_steps .to (last .device )
4952 episode_steps = backend .concatenate ([first_steps , last_idx [1 :] - last_idx [:- 1 ]])
5053 max_episode_steps = episode_steps .max ()
5154
52- start_idx = backend .concatenate ([backend .zeros (1 , dtype = int , device = last .device if hasattr ( last , 'device' ) else None ), last_idx [:- 1 ] + 1 ])
55+ start_idx = backend .concatenate ([backend .zeros (1 , dtype = int , device = last .device if backend . get_backend_name () == 'torch' else None ), last_idx [:- 1 ] + 1 ])
5356 range_n_episodes = backend .arange (0 , n_episodes , dtype = int )
5457 range_len = backend .arange (0 , last .shape [0 ], dtype = int )
55- if hasattr ( last , 'device' ) :
58+ if backend . get_backend_name () == 'torch' :
5659 range_n_episodes = range_n_episodes .to (last .device )
5760 range_len = range_len .to (last .device )
5861 row_idx = backend .repeat (range_n_episodes , episode_steps )
0 commit comments