Skip to content

Commit ecd47e9

Browse files
timonmerkstes
authored andcommitted
add positive sampling options for MixedDataLoader
1 parent 7a4d3fc commit ecd47e9

File tree

1 file changed

+45
-14
lines changed

1 file changed

+45
-14
lines changed

cebra/data/single_session.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -268,27 +268,47 @@ class MixedDataLoader(cebra_data.Loader):
268268
1. Positive pairs always share their discrete variable.
269269
2. Positive pairs are drawn only based on their conditional,
270270
not discrete variable.
271+
272+
Args:
273+
conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
274+
time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
275+
positive_sampling (str): either "discrete_variable" (default) or "conditional"
276+
discrete_sampling_prior (str): either "empirical" (default) or "uniform"
271277
"""
272278

273279
conditional: str = dataclasses.field(default="time_delta")
274280
time_offset: int = dataclasses.field(default=10)
281+
positive_sampling: str = dataclasses.field(default="discrete_variable")
282+
discrete_sampling_prior: str = dataclasses.field(default="uniform")
275283

276284
@property
277-
def dindex(self):
278-
# TODO(stes) rename to discrete_index
285+
def discrete_index(self):
279286
return self.dataset.discrete_index
280287

281288
@property
282-
def cindex(self):
283-
# TODO(stes) rename to continuous_index
289+
def continuous_index(self):
284290
return self.dataset.continuous_index
285291

286292
def __post_init__(self):
287293
super().__post_init__()
288-
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
289-
discrete=self.dindex,
290-
continuous=self.cindex,
291-
time_delta=self.time_offset)
294+
if self.positive_sampling == "conditional":
295+
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
296+
discrete=self.discrete_index,
297+
continuous=self.continuous_index,
298+
time_delta=self.time_offset)
299+
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical":
300+
self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index)
301+
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform":
302+
self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index)
303+
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]:
304+
raise ValueError(
305+
f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but "
306+
f"only accept 'uniform' or 'empirical' as potential values.")
307+
else:
308+
raise ValueError(
309+
f"Invalid positive sampling mode: "
310+
f"{self.positive_sampling} valid options are "
311+
f"'conditional' or 'discrete_variable'.")
292312

293313
def get_indices(self, num_samples: int) -> BatchIndex:
294314
"""Samples indices for reference, positive and negative examples.
@@ -313,12 +333,23 @@ def get_indices(self, num_samples: int) -> BatchIndex:
313333
class.
314334
- Sample the negatives with matching discrete variable
315335
"""
316-
reference_idx = self.distribution.sample_prior(num_samples)
317-
return BatchIndex(
318-
reference=reference_idx,
319-
negative=self.distribution.sample_prior(num_samples),
320-
positive=self.distribution.sample_conditional(reference_idx),
321-
)
336+
if self.positive_sampling == "conditional":
337+
reference_idx = self.distribution.sample_prior(num_samples)
338+
return BatchIndex(
339+
reference=reference_idx,
340+
negative=self.distribution.sample_prior(num_samples),
341+
positive=self.distribution.sample_conditional(reference_idx),
342+
)
343+
else:
344+
# taken from the DiscreteDataLoader get_indices function
345+
reference_idx = self.distribution.sample_prior(num_samples * 2)
346+
negative_idx = reference_idx[num_samples:]
347+
reference_idx = reference_idx[:num_samples]
348+
reference = self.discrete_index[reference_idx]
349+
positive_idx = self.distribution.sample_conditional(reference)
350+
return BatchIndex(reference=reference_idx,
351+
positive=positive_idx,
352+
negative=negative_idx)
322353

323354

324355
@dataclasses.dataclass

0 commit comments

Comments
 (0)