Skip to content

Commit 723bcfb

Browse files
committed
Add support to sample more negatives
1 parent e982248 commit 723bcfb

File tree

4 files changed

+70
-20
lines changed

4 files changed

+70
-20
lines changed

cebra/data/base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ class Loader(abc.ABC, cebra.io.HasDevice):
239239
batch_size: int = dataclasses.field(default=None,
240240
doc="""The total batch size.""")
241241

242+
num_negatives: int = dataclasses.field(
243+
default=None,
244+
doc="""The number of negative samples to draw for each reference.
245+
If not specified, the batch size is used."""
246+
)
247+
242248
def __post_init__(self):
243249
if self.num_steps is None or self.num_steps <= 0:
244250
raise ValueError(
@@ -255,11 +261,12 @@ def __len__(self):
255261

256262
def __iter__(self) -> Batch:
257263
for _ in range(len(self)):
258-
index = self.get_indices(num_samples=self.batch_size)
264+
index = self.get_indices(num_samples=self.batch_size,
265+
num_negatives=self.num_negatives)
259266
yield self.dataset.load_batch(index)
260267

261268
@abc.abstractmethod
262-
def get_indices(self, num_samples: int):
269+
def get_indices(self, num_samples: int, num_negatives: int = None):
263270
"""Sample and return the specified number of indices.
264271
265272
The elements of the returned `BatchIndex` will be used to index the
@@ -271,5 +278,10 @@ def get_indices(self, num_samples: int):
271278
272279
Returns:
273280
batch indices for the reference, positive and negative sample.
281+
282+
283+
Note:
284+
From version 0.7.0 onwards, the `num_negatives` parameter can be used to
285+
request a number of negative samples that differs from the batch size.
274286
"""
275287
raise NotImplementedError()

cebra/data/multi_session.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,14 @@ def __post_init__(self):
155155
super().__post_init__()
156156
self.sampler = cebra.distributions.MultisessionSampler(
157157
self.dataset, self.time_offset)
158+
if self.num_negatives is None:
159+
self.num_negatives = self.batch_size
158160

159-
def get_indices(self, num_samples: int) -> List[BatchIndex]:
161+
# NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument
162+
is not used in the multi-session case, which differs from the single-session case.
163+
def get_indices(self, num_samples) -> List[BatchIndex]:
160164
ref_idx = self.sampler.sample_prior(self.batch_size)
161-
neg_idx = self.sampler.sample_prior(self.batch_size)
165+
neg_idx = self.sampler.sample_prior(self.num_negatives)
162166
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)
163167

164168
ref_idx = torch.from_numpy(ref_idx)
@@ -251,7 +255,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
251255
Batch indices for the reference, positive and negative samples.
252256
"""
253257
ref_idx = self.sampler.sample_prior(self.batch_size)
254-
neg_idx = self.sampler.sample_prior(self.batch_size)
258+
neg_idx = self.sampler.sample_prior(self.num_negatives)
255259

256260
pos_idx = self.sampler.sample_conditional(ref_idx)
257261

cebra/data/multiobjective.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __post_init__(self):
7171
def add_config(self, config):
7272
self.labels.append(config['label'])
7373

74-
def get_indices(self, num_samples: int):
74+
def get_indices(self, num_samples: int, num_negatives: int = None):
7575
if self.sampling_mode_supervised == "ref_shared":
7676
reference_idx = self.prior.sample_prior(num_samples)
7777
else:
@@ -142,11 +142,14 @@ def add_config(self, config):
142142

143143
self.distributions.append(distribution)
144144

145-
def get_indices(self, num_samples: int):
145+
def get_indices(self, num_samples: int, num_negatives: int = None):
146146
"""Sample and return the specified number of indices."""
147147

148+
if num_negatives is None:
149+
num_negatives = num_samples
150+
148151
if self.sampling_mode_contrastive == "refneg_shared":
149-
ref_and_neg = self.prior.sample_prior(num_samples * 2)
152+
ref_and_neg = self.prior.sample_prior(num_samples + num_negatives)
150153
reference_idx = ref_and_neg[:num_samples]
151154
negative_idx = ref_and_neg[num_samples:]
152155

@@ -169,5 +172,6 @@ def get_indices(self, num_samples: int):
169172

170173
def __iter__(self):
171174
for _ in range(len(self)):
172-
index = self.get_indices(num_samples=self.batch_size)
175+
index = self.get_indices(num_samples=self.batch_size,
176+
num_negatives=self.num_negatives)
173177
yield self.dataset.load_batch_contrastive(index)

cebra/data/single_session.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ def _init_distribution(self):
138138
f"Invalid choice of prior distribution. Got '{self.prior}', but "
139139
f"only accept 'uniform' or 'empirical' as potential values.")
140140

141-
def get_indices(self, num_samples: int) -> BatchIndex:
141+
def get_indices(self,
142+
num_samples: int,
143+
num_negatives: int = None) -> BatchIndex:
142144
"""Samples indices for reference, positive and negative examples.
143145
144146
The reference samples will be sampled from the empirical or uniform prior
@@ -154,11 +156,16 @@ def get_indices(self, num_samples: int) -> BatchIndex:
154156
Args:
155157
num_samples: The number of samples (batch size) of the returned
156158
:py:class:`cebra.data.datatypes.BatchIndex`.
159+
num_negatives: The number of negative samples. If None, defaults to num_samples.
157160
158161
Returns:
159162
Indices for reference, positive and negative samples.
160163
"""
161-
reference_idx = self.distribution.sample_prior(num_samples * 2)
164+
if num_negatives is None:
165+
num_negatives = num_samples
166+
167+
reference_idx = self.distribution.sample_prior(num_samples +
168+
num_negatives)
162169
negative_idx = reference_idx[num_samples:]
163170
reference_idx = reference_idx[:num_samples]
164171
reference = self.index[reference_idx]
@@ -246,7 +253,9 @@ def _init_distribution(self):
246253
else:
247254
raise ValueError(self.conditional)
248255

249-
def get_indices(self, num_samples: int) -> BatchIndex:
256+
def get_indices(self,
257+
num_samples: int,
258+
num_negatives: int = None) -> BatchIndex:
250259
"""Samples indices for reference, positive and negative examples.
251260
252261
The reference and negative samples will be sampled uniformly from
@@ -262,7 +271,11 @@ def get_indices(self, num_samples: int) -> BatchIndex:
262271
Returns:
263272
Indices for reference, positive and negative samples.
264273
"""
265-
reference_idx = self.distribution.sample_prior(num_samples * 2)
274+
if num_negatives is None:
275+
num_negatives = num_samples
276+
277+
reference_idx = self.distribution.sample_prior(num_samples +
278+
num_negatives)
266279
negative_idx = reference_idx[num_samples:]
267280
reference_idx = reference_idx[:num_samples]
268281
positive_idx = self.distribution.sample_conditional(reference_idx)
@@ -305,7 +318,9 @@ def __post_init__(self):
305318
continuous=self.cindex,
306319
time_delta=self.time_offset)
307320

308-
def get_indices(self, num_samples: int) -> BatchIndex:
321+
def get_indices(self,
322+
num_samples: int,
323+
num_negatives: int = None) -> BatchIndex:
309324
"""Samples indices for reference, positive and negative examples.
310325
311326
The reference and negative samples will be sampled uniformly from
@@ -319,6 +334,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
319334
Args:
320335
num_samples: The number of samples (batch size) of the returned
321336
:py:class:`cebra.data.datatypes.BatchIndex`.
337+
num_negatives: The number of negative samples. If None, defaults to num_samples.
322338
323339
Returns:
324340
Indices for reference, positive and negative samples.
@@ -328,10 +344,16 @@ def get_indices(self, num_samples: int) -> BatchIndex:
328344
class.
329345
- Sample the negatives with matching discrete variable
330346
"""
331-
reference_idx = self.distribution.sample_prior(num_samples)
347+
if num_negatives is None:
348+
num_negatives = num_samples
349+
350+
reference_idx = self.distribution.sample_prior(num_samples +
351+
num_negatives)
352+
negative_idx = reference_idx[num_samples:]
353+
reference_idx = reference_idx[:num_samples]
332354
return BatchIndex(
333355
reference=reference_idx,
334-
negative=self.distribution.sample_prior(num_samples),
356+
negative=negative_idx,
335357
positive=self.distribution.sample_conditional(reference_idx),
336358
)
337359

@@ -421,11 +443,13 @@ def _init_time_distribution(self):
421443
else:
422444
raise ValueError
423445

424-
def get_indices(self, num_samples: int) -> BatchIndex:
446+
def get_indices(self,
447+
num_samples: int,
448+
num_negatives: int = None) -> BatchIndex:
425449
"""Samples indices for reference, positive and negative examples.
426450
427451
The reference and negative samples will be sampled uniformly from
428-
all available time steps, and a total of ``2*num_samples`` will be
452+
all available time steps, and a total of ``num_samples + num_negatives`` will be
429453
returned for both.
430454
431455
For the positive samples, ``num_samples`` are sampled according to the
@@ -436,6 +460,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
436460
Args:
437461
num_samples: The number of samples (batch size) of the returned
438462
:py:class:`cebra.data.datatypes.BatchIndex`.
463+
num_negatives: The number of negative samples. If None, defaults to num_samples.
439464
440465
Returns:
441466
Indices for reference, positive and negative samples.
@@ -444,7 +469,11 @@ def get_indices(self, num_samples: int) -> BatchIndex:
444469
Add the ``empirical`` vs. ``discrete`` sampling modes to this
445470
class.
446471
"""
447-
reference_idx = self.time_distribution.sample_prior(num_samples * 2)
472+
if num_negatives is None:
473+
num_negatives = num_samples
474+
475+
reference_idx = self.time_distribution.sample_prior(num_samples +
476+
num_negatives)
448477
negative_idx = reference_idx[num_samples:]
449478
reference_idx = reference_idx[:num_samples]
450479
behavior_positive_idx = self.behavior_distribution.sample_conditional(
@@ -470,7 +499,7 @@ def __post_init__(self):
470499
def offset(self):
471500
return self.dataset.offset
472501

473-
def get_indices(self, num_samples=None) -> BatchIndex:
502+
def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex:
474503
"""Samples indices for reference, positive and negative examples.
475504
476505
The reference indices are all available (valid, according to the
@@ -491,6 +520,7 @@ def get_indices(self, num_samples=None) -> BatchIndex:
491520
class.
492521
"""
493522
assert num_samples is None
523+
assert num_negatives is None
494524

495525
reference_idx = torch.arange(
496526
self.offset.left,

0 commit comments

Comments
 (0)