diff --git a/the_well/data/datasets.py b/the_well/data/datasets.py index ccb6ab62..1254a9c0 100755 --- a/the_well/data/datasets.py +++ b/the_well/data/datasets.py @@ -389,8 +389,9 @@ def _build_restriction_set( - 1 ) if traj in trajectories_sampled: + n_windows = self.n_windows_per_trajectory[file_index] global_indices = global_indices + list( - range(0, self.n_windows_per_trajectory[file_index]) + range(current_index, current_index + n_windows) ) current_index += self.n_windows_per_trajectory[file_index] global_indices = np.array(global_indices)