Skip to content

Commit 0326fb9

Browse files
MMathisLabstes
andauthored
Update cebra/data/single_session.py
Co-authored-by: Steffen Schneider <[email protected]>
1 parent 8dee8a0 commit 0326fb9

File tree

1 file changed

+1
-9
lines changed

1 file changed

+1
-9
lines changed

cebra/data/single_session.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
360360
positive=self.distribution.sample_conditional(reference_idx),
361361
)
362362
else:
363-
# taken from the DiscreteDataLoader get_indices function
364-
reference_idx = self.distribution.sample_prior(num_samples * 2)
365-
negative_idx = reference_idx[num_samples:]
366-
reference_idx = reference_idx[:num_samples]
367-
reference = self.discrete_index[reference_idx]
368-
positive_idx = self.distribution.sample_conditional(reference)
369-
return BatchIndex(reference=reference_idx,
370-
positive=positive_idx,
371-
negative=negative_idx)
363+
return self.distribution.get_indices(num_samples)
372364

373365

374366
@dataclasses.dataclass

0 commit comments

Comments
 (0)