Skip to content

Commit 540b006

Browse files
committed
Fix multi-session samplers
1 parent dbabb6e commit 540b006

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

cebra/data/multi_session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,11 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
198198
# Overwrite sampler with the discrete implementation
199199
# Generalize MultisessionSampler to avoid doing this?
200200
def __post_init__(self):
201+
# NOTE(stes): __post_init__ from superclass is intentionally not called.
201202
self.sampler = cebra.distributions.DiscreteMultisessionSampler(
202203
self.dataset)
204+
if self.num_negatives is None:
205+
self.num_negatives = self.batch_size
203206

204207
@property
205208
def index(self):
@@ -235,7 +238,9 @@ def __post_init__(self):
235238
self.sampler = cebra.distributions.UnifiedSampler(
236239
self.dataset, self.time_offset)
237240

238-
def get_indices(self, num_samples: int) -> BatchIndex:
241+
def get_indices(self,
242+
num_samples: int,
243+
num_negatives: int = None) -> BatchIndex:
239244
"""Sample and return the specified number of indices.
240245
241246
The elements of the returned ``BatchIndex`` will be used to index the

0 commit comments

Comments
 (0)