@@ -268,27 +268,47 @@ class MixedDataLoader(cebra_data.Loader):
268268 1. Positive pairs always share their discrete variable.
269269 2. Positive pairs are drawn only based on their conditional,
270270 not discrete variable.
271+
272+ Args:
273+ conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
274+ time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
275+ positive_sampling (str): either "discrete_variable" (default) or "conditional"
275+ discrete_sampling_prior (str): either "empirical" or "uniform" (default)
271277 """
272278
# Conditional variable for sampling positive pairs (per the class docstring).
conditional: str = dataclasses.field(default="time_delta")
# Forwarded as ``time_delta`` to the mixed time-delta distribution, which is
# only built when ``positive_sampling`` is "conditional".
time_offset: int = dataclasses.field(default=10)
# How positive samples are drawn: "discrete_variable" or "conditional".
positive_sampling: str = dataclasses.field(default="discrete_variable")
# Prior over the discrete index, consulted only when ``positive_sampling`` is
# "discrete_variable": "uniform" or "empirical".
discrete_sampling_prior: str = dataclasses.field(default="uniform")
275283
@property
def discrete_index(self):
    """Return the discrete index of the wrapped dataset.

    This is a read-only view onto ``self.dataset.discrete_index``; no copy
    or conversion is performed here.
    """
    return self.dataset.discrete_index
280287
@property
def continuous_index(self):
    """Return the continuous index of the wrapped dataset.

    This is a read-only view onto ``self.dataset.continuous_index``; no copy
    or conversion is performed here.
    """
    return self.dataset.continuous_index
285291
def __post_init__(self):
    """Validate the sampling configuration and build ``self.distribution``.

    The distribution is selected from ``positive_sampling`` and, in
    ``"discrete_variable"`` mode, from ``discrete_sampling_prior``:

    * ``"conditional"``: ``MixedTimeDeltaDistribution`` built from the
      discrete and continuous index with ``time_delta=self.time_offset``.
    * ``"discrete_variable"`` with ``"empirical"``: ``DiscreteEmpirical``
      built from the discrete index.
    * ``"discrete_variable"`` with ``"uniform"``: ``DiscreteUniform``
      built from the discrete index.

    Raises:
        ValueError: If ``positive_sampling`` is neither ``"conditional"``
            nor ``"discrete_variable"``, or if, in ``"discrete_variable"``
            mode, ``discrete_sampling_prior`` is neither ``"uniform"`` nor
            ``"empirical"``.
    """
    super().__post_init__()

    if self.positive_sampling == "conditional":
        self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
            discrete=self.discrete_index,
            continuous=self.continuous_index,
            time_delta=self.time_offset)
        return

    if self.positive_sampling != "discrete_variable":
        raise ValueError(
            f"Invalid positive sampling mode: "
            f"'{self.positive_sampling}'. Valid options are "
            f"'conditional' or 'discrete_variable'.")

    # From here on we are in "discrete_variable" mode; the prior is only
    # consulted (and therefore only validated) in this mode.
    if self.discrete_sampling_prior == "empirical":
        self.distribution = cebra.distributions.DiscreteEmpirical(
            self.discrete_index)
    elif self.discrete_sampling_prior == "uniform":
        self.distribution = cebra.distributions.DiscreteUniform(
            self.discrete_index)
    else:
        raise ValueError(
            f"Invalid choice of prior distribution. Got "
            f"'{self.discrete_sampling_prior}', but "
            f"only accept 'uniform' or 'empirical' as potential values.")
292312
293313 def get_indices (self , num_samples : int ) -> BatchIndex :
294314 """Samples indices for reference, positive and negative examples.
@@ -313,12 +333,23 @@ def get_indices(self, num_samples: int) -> BatchIndex:
313333 class.
314334 - Sample the negatives with matching discrete variable
315335 """
316- reference_idx = self .distribution .sample_prior (num_samples )
317- return BatchIndex (
318- reference = reference_idx ,
319- negative = self .distribution .sample_prior (num_samples ),
320- positive = self .distribution .sample_conditional (reference_idx ),
321- )
336+ if self .positive_sampling == "conditional" :
337+ reference_idx = self .distribution .sample_prior (num_samples )
338+ return BatchIndex (
339+ reference = reference_idx ,
340+ negative = self .distribution .sample_prior (num_samples ),
341+ positive = self .distribution .sample_conditional (reference_idx ),
342+ )
343+ else :
344+ # taken from the DiscreteDataLoader get_indices function
345+ reference_idx = self .distribution .sample_prior (num_samples * 2 )
346+ negative_idx = reference_idx [num_samples :]
347+ reference_idx = reference_idx [:num_samples ]
348+ reference = self .discrete_index [reference_idx ]
349+ positive_idx = self .distribution .sample_conditional (reference )
350+ return BatchIndex (reference = reference_idx ,
351+ positive = positive_idx ,
352+ negative = negative_idx )
322353
323354
324355@dataclasses .dataclass
0 commit comments