@@ -138,7 +138,9 @@ def _init_distribution(self):
138138 f"Invalid choice of prior distribution. Got '{ self .prior } ', but "
139139 f"only accept 'uniform' or 'empirical' as potential values." )
140140
141- def get_indices (self , num_samples : int ) -> BatchIndex :
141+ def get_indices (self ,
142+ num_samples : int ,
143+ num_negatives : int = None ) -> BatchIndex :
142144 """Samples indices for reference, positive and negative examples.
143145
144146 The reference samples will be sampled from the empirical or uniform prior
@@ -154,11 +156,16 @@ def get_indices(self, num_samples: int) -> BatchIndex:
154156 Args:
155157 num_samples: The number of samples (batch size) of the returned
156158 :py:class:`cebra.data.datatypes.BatchIndex`.
159+ num_negatives: The number of negative samples. If None, defaults to num_samples.
157160
158161 Returns:
159162 Indices for reference, positive and negatives samples.
160163 """
161- reference_idx = self .distribution .sample_prior (num_samples * 2 )
164+ if num_negatives is None :
165+ num_negatives = num_samples
166+
167+ reference_idx = self .distribution .sample_prior (num_samples +
168+ num_negatives )
162169 negative_idx = reference_idx [num_samples :]
163170 reference_idx = reference_idx [:num_samples ]
164171 reference = self .index [reference_idx ]
@@ -246,7 +253,9 @@ def _init_distribution(self):
246253 else :
247254 raise ValueError (self .conditional )
248255
249- def get_indices (self , num_samples : int ) -> BatchIndex :
256+ def get_indices (self ,
257+ num_samples : int ,
258+ num_negatives : int = None ) -> BatchIndex :
250259 """Samples indices for reference, positive and negative examples.
251260
252261 The reference and negative samples will be sampled uniformly from
@@ -262,7 +271,11 @@ def get_indices(self, num_samples: int) -> BatchIndex:
262271 Returns:
263272 Indices for reference, positive and negatives samples.
264273 """
265- reference_idx = self .distribution .sample_prior (num_samples * 2 )
274+ if num_negatives is None :
275+ num_negatives = num_samples
276+
277+ reference_idx = self .distribution .sample_prior (num_samples +
278+ num_negatives )
266279 negative_idx = reference_idx [num_samples :]
267280 reference_idx = reference_idx [:num_samples ]
268281 positive_idx = self .distribution .sample_conditional (reference_idx )
@@ -305,7 +318,9 @@ def __post_init__(self):
305318 continuous = self .cindex ,
306319 time_delta = self .time_offset )
307320
308- def get_indices (self , num_samples : int ) -> BatchIndex :
321+ def get_indices (self ,
322+ num_samples : int ,
323+ num_negatives : int = None ) -> BatchIndex :
309324 """Samples indices for reference, positive and negative examples.
310325
311326 The reference and negative samples will be sampled uniformly from
@@ -319,6 +334,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
319334 Args:
320335 num_samples: The number of samples (batch size) of the returned
321336 :py:class:`cebra.data.datatypes.BatchIndex`.
337+ num_negatives: The number of negative samples. If None, defaults to num_samples.
322338
323339 Returns:
324340 Indices for reference, positive and negatives samples.
@@ -328,10 +344,16 @@ def get_indices(self, num_samples: int) -> BatchIndex:
328344 class.
329345 - Sample the negatives with matching discrete variable
330346 """
331- reference_idx = self .distribution .sample_prior (num_samples )
347+ if num_negatives is None :
348+ num_negatives = num_samples
349+
350+ reference_idx = self .distribution .sample_prior (num_samples +
351+ num_negatives )
352+ negative_idx = reference_idx [num_samples :]
353+ reference_idx = reference_idx [:num_samples ]
332354 return BatchIndex (
333355 reference = reference_idx ,
334- negative = self . distribution . sample_prior ( num_samples ) ,
356+ negative = negative_idx ,
335357 positive = self .distribution .sample_conditional (reference_idx ),
336358 )
337359
@@ -421,11 +443,13 @@ def _init_time_distribution(self):
421443 else :
422444 raise ValueError
423445
424- def get_indices (self , num_samples : int ) -> BatchIndex :
446+ def get_indices (self ,
447+ num_samples : int ,
448+ num_negatives : int = None ) -> BatchIndex :
425449 """Samples indices for reference, positive and negative examples.
426450
427451 The reference and negative samples will be sampled uniformly from
428- all available time steps, and a total of ``2* num_samples`` will be
452+ all available time steps, and a total of ``num_samples + num_negatives `` will be
429453 returned for both.
430454
431455 For the positive samples, ``num_samples`` are sampled according to the
@@ -436,6 +460,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
436460 Args:
437461 num_samples: The number of samples (batch size) of the returned
438462 :py:class:`cebra.data.datatypes.BatchIndex`.
463+ num_negatives: The number of negative samples. If None, defaults to num_samples.
439464
440465 Returns:
441466 Indices for reference, positive and negatives samples.
@@ -444,7 +469,11 @@ def get_indices(self, num_samples: int) -> BatchIndex:
444469 Add the ``empirical`` vs. ``discrete`` sampling modes to this
445470 class.
446471 """
447- reference_idx = self .time_distribution .sample_prior (num_samples * 2 )
472+ if num_negatives is None :
473+ num_negatives = num_samples
474+
475+ reference_idx = self .time_distribution .sample_prior (num_samples +
476+ num_negatives )
448477 negative_idx = reference_idx [num_samples :]
449478 reference_idx = reference_idx [:num_samples ]
450479 behavior_positive_idx = self .behavior_distribution .sample_conditional (
@@ -470,7 +499,7 @@ def __post_init__(self):
470499 def offset (self ):
471500 return self .dataset .offset
472501
473- def get_indices (self , num_samples = None ) -> BatchIndex :
502+ def get_indices (self , num_samples = None , num_negatives = None ) -> BatchIndex :
474503 """Samples indices for reference, positive and negative examples.
475504
476505 The reference indices are all available (valid, according to the
@@ -491,6 +520,7 @@ def get_indices(self, num_samples=None) -> BatchIndex:
491520 class.
492521 """
493522 assert num_samples is None
523+ assert num_negatives is None
494524
495525 reference_idx = torch .arange (
496526 self .offset .left ,
0 commit comments