File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed
Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -198,8 +198,11 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
198198 # Overwrite sampler with the discrete implementation
199199 # Generalize MultisessionSampler to avoid doing this?
200200 def __post_init__ (self ):
201+ # NOTE(stes): __post_init__ from superclass is intentionally not called.
201202 self .sampler = cebra .distributions .DiscreteMultisessionSampler (
202203 self .dataset )
204+ if self .num_negatives is None :
205+ self .num_negatives = self .batch_size
203206
204207 @property
205208 def index (self ):
@@ -235,7 +238,9 @@ def __post_init__(self):
235238 self .sampler = cebra .distributions .UnifiedSampler (
236239 self .dataset , self .time_offset )
237240
238- def get_indices (self , num_samples : int ) -> BatchIndex :
241+ def get_indices (self ,
242+ num_samples : int ,
243+ num_negatives : int = None ) -> BatchIndex :
239244 """Sample and return the specified number of indices.
240245
241246 The elements of the returned ``BatchIndex`` will be used to index the
You can’t perform that action at this time.
0 commit comments