Skip to content

Commit 07212f2

Browse files
committed
Improve sampling API
1 parent 540b006 commit 07212f2

File tree

5 files changed

+165
-126
lines changed

5 files changed

+165
-126
lines changed

cebra/data/base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""Base classes for datasets and loaders."""
2323

2424
import abc
25+
from typing import Iterator
2526

2627
import literate_dataclasses as dataclasses
2728
import torch
@@ -254,19 +255,25 @@ def __post_init__(self):
254255
raise ValueError(
255256
f"Batch size has to be None, or a non-negative value. Got {self.batch_size}."
256257
)
258+
if self.num_negatives is not None and self.num_negatives <= 0:
259+
raise ValueError(
260+
f"Number of negatives has to be None, or a non-negative value. Got {self.num_negatives}."
261+
)
262+
263+
if self.num_negatives is None:
264+
self.num_negatives = self.batch_size
257265

258266
def __len__(self):
259267
"""The number of batches returned when calling as an iterator."""
260268
return self.num_steps
261269

262-
def __iter__(self) -> Batch:
270+
def __iter__(self) -> Iterator[Batch]:
263271
for _ in range(len(self)):
264-
index = self.get_indices(num_samples=self.batch_size,
265-
num_negatives=self.num_negatives)
272+
index = self.get_indices()
266273
yield self.dataset.load_batch(index)
267274

268275
@abc.abstractmethod
269-
def get_indices(self, num_samples: int, num_negatives: int = None):
276+
def get_indices(self):
270277
"""Sample and return the specified number of indices.
271278
272279
The elements of the returned `BatchIndex` will be used to index the

cebra/data/multi_session.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,7 @@ def __post_init__(self):
160160

161161
# NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument
162162
# is not used in the multi-session case, which is different to the single session samples.
163-
def get_indices(self,
164-
num_samples: int,
165-
num_negatives: int = None) -> List[BatchIndex]:
163+
def get_indices(self) -> List[BatchIndex]:
166164
ref_idx = self.sampler.sample_prior(self.batch_size)
167165
neg_idx = self.sampler.sample_prior(self.num_negatives)
168166
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)
@@ -238,9 +236,14 @@ def __post_init__(self):
238236
self.sampler = cebra.distributions.UnifiedSampler(
239237
self.dataset, self.time_offset)
240238

241-
def get_indices(self,
242-
num_samples: int,
243-
num_negatives: int = None) -> BatchIndex:
239+
if self.batch_size < 2:
240+
raise ValueError("UnifiedLoader does not support batch_size < 2.")
241+
242+
if self.num_negatives < 2:
243+
raise ValueError(
244+
"UnifiedLoader does not support num_negatives < 2.")
245+
246+
def get_indices(self) -> BatchIndex:
244247
"""Sample and return the specified number of indices.
245248
246249
The elements of the returned ``BatchIndex`` will be used to index the

cebra/data/multiobjective.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
# limitations under the License.
2121
#
2222

23+
from typing import Iterator
24+
2325
import literate_dataclasses as dataclasses
2426

2527
import cebra.data as cebra_data
2628
import cebra.distributions
29+
from cebra.data.datatypes import Batch
2730
from cebra.data.datatypes import BatchIndex
2831
from cebra.distributions.continuous import Prior
2932

@@ -71,9 +74,9 @@ def __post_init__(self):
7174
def add_config(self, config):
7275
self.labels.append(config['label'])
7376

74-
def get_indices(self, num_samples: int, num_negatives: int = None):
77+
def get_indices(self) -> BatchIndex:
7578
if self.sampling_mode_supervised == "ref_shared":
76-
reference_idx = self.prior.sample_prior(num_samples)
79+
reference_idx = self.prior.sample_prior(self.batch_size)
7780
else:
7881
raise ValueError(
7982
f"Sampling mode {self.sampling_mode_supervised} is not implemented."
@@ -87,9 +90,9 @@ def get_indices(self, num_samples: int, num_negatives: int = None):
8790

8891
return batch_index
8992

90-
def __iter__(self):
93+
def __iter__(self) -> Iterator[Batch]:
9194
for _ in range(len(self)):
92-
index = self.get_indices(num_samples=self.batch_size)
95+
index = self.get_indices()
9396
yield self.dataset.load_batch_supervised(index, self.labels)
9497

9598

@@ -142,16 +145,14 @@ def add_config(self, config):
142145

143146
self.distributions.append(distribution)
144147

145-
def get_indices(self, num_samples: int, num_negatives: int = None):
148+
def get_indices(self) -> BatchIndex:
146149
"""Sample and return the specified number of indices."""
147150

148-
if num_negatives is None:
149-
num_negatives = num_samples
150-
151151
if self.sampling_mode_contrastive == "refneg_shared":
152-
ref_and_neg = self.prior.sample_prior(num_samples + num_negatives)
153-
reference_idx = ref_and_neg[:num_samples]
154-
negative_idx = ref_and_neg[num_samples:]
152+
ref_and_neg = self.prior.sample_prior(self.batch_size +
153+
self.num_negatives)
154+
reference_idx = ref_and_neg[:self.batch_size]
155+
negative_idx = ref_and_neg[self.batch_size:]
155156

156157
positives_idx = []
157158
for distribution in self.distributions:
@@ -172,6 +173,5 @@ def get_indices(self, num_samples: int, num_negatives: int = None):
172173

173174
def __iter__(self):
174175
for _ in range(len(self)):
175-
index = self.get_indices(num_samples=self.batch_size,
176-
num_negatives=self.num_negatives)
176+
index = self.get_indices()
177177
yield self.dataset.load_batch_contrastive(index)

cebra/data/single_session.py

Lines changed: 30 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import abc
2929
import warnings
30+
from typing import Iterator
3031

3132
import literate_dataclasses as dataclasses
3233
import torch
@@ -138,9 +139,7 @@ def _init_distribution(self):
138139
f"Invalid choice of prior distribution. Got '{self.prior}', but "
139140
f"only accept 'uniform' or 'empirical' as potential values.")
140141

141-
def get_indices(self,
142-
num_samples: int,
143-
num_negatives: int = None) -> BatchIndex:
142+
def get_indices(self) -> BatchIndex:
144143
"""Samples indices for reference, positive and negative examples.
145144
146145
The reference samples will be sampled from the empirical or uniform prior
@@ -161,13 +160,10 @@ def get_indices(self,
161160
Returns:
162161
Indices for reference, positive and negatives samples.
163162
"""
164-
if num_negatives is None:
165-
num_negatives = num_samples
166-
167-
reference_idx = self.distribution.sample_prior(num_samples +
168-
num_negatives)
169-
negative_idx = reference_idx[num_samples:]
170-
reference_idx = reference_idx[:num_samples]
163+
reference_idx = self.distribution.sample_prior(self.batch_size +
164+
self.num_negatives)
165+
negative_idx = reference_idx[self.batch_size:]
166+
reference_idx = reference_idx[:self.batch_size]
171167
reference = self.index[reference_idx]
172168
positive_idx = self.distribution.sample_conditional(reference)
173169
return BatchIndex(reference=reference_idx,
@@ -253,9 +249,7 @@ def _init_distribution(self):
253249
else:
254250
raise ValueError(self.conditional)
255251

256-
def get_indices(self,
257-
num_samples: int,
258-
num_negatives: int = None) -> BatchIndex:
252+
def get_indices(self) -> BatchIndex:
259253
"""Samples indices for reference, positive and negative examples.
260254
261255
The reference and negative samples will be sampled uniformly from
@@ -271,13 +265,10 @@ def get_indices(self,
271265
Returns:
272266
Indices for reference, positive and negatives samples.
273267
"""
274-
if num_negatives is None:
275-
num_negatives = num_samples
276-
277-
reference_idx = self.distribution.sample_prior(num_samples +
278-
num_negatives)
279-
negative_idx = reference_idx[num_samples:]
280-
reference_idx = reference_idx[:num_samples]
268+
reference_idx = self.distribution.sample_prior(self.batch_size +
269+
self.num_negatives)
270+
negative_idx = reference_idx[self.batch_size:]
271+
reference_idx = reference_idx[:self.batch_size]
281272
positive_idx = self.distribution.sample_conditional(reference_idx)
282273
return BatchIndex(reference=reference_idx,
283274
positive=positive_idx,
@@ -318,9 +309,7 @@ def __post_init__(self):
318309
continuous=self.cindex,
319310
time_delta=self.time_offset)
320311

321-
def get_indices(self,
322-
num_samples: int,
323-
num_negatives: int = None) -> BatchIndex:
312+
def get_indices(self) -> BatchIndex:
324313
"""Samples indices for reference, positive and negative examples.
325314
326315
The reference and negative samples will be sampled uniformly from
@@ -344,13 +333,10 @@ def get_indices(self,
344333
class.
345334
- Sample the negatives with matching discrete variable
346335
"""
347-
if num_negatives is None:
348-
num_negatives = num_samples
349-
350-
reference_idx = self.distribution.sample_prior(num_samples +
351-
num_negatives)
352-
negative_idx = reference_idx[num_samples:]
353-
reference_idx = reference_idx[:num_samples]
336+
reference_idx = self.distribution.sample_prior(self.batch_size +
337+
self.num_negatives)
338+
negative_idx = reference_idx[self.batch_size:]
339+
reference_idx = reference_idx[:self.batch_size]
354340
return BatchIndex(
355341
reference=reference_idx,
356342
negative=negative_idx,
@@ -443,9 +429,7 @@ def _init_time_distribution(self):
443429
else:
444430
raise ValueError
445431

446-
def get_indices(self,
447-
num_samples: int,
448-
num_negatives: int = None) -> BatchIndex:
432+
def get_indices(self) -> BatchIndex:
449433
"""Samples indices for reference, positive and negative examples.
450434
451435
The reference and negative samples will be sampled uniformly from
@@ -469,13 +453,10 @@ def get_indices(self,
469453
Add the ``empirical`` vs. ``discrete`` sampling modes to this
470454
class.
471455
"""
472-
if num_negatives is None:
473-
num_negatives = num_samples
474-
475-
reference_idx = self.time_distribution.sample_prior(num_samples +
476-
num_negatives)
477-
negative_idx = reference_idx[num_samples:]
478-
reference_idx = reference_idx[:num_samples]
456+
reference_idx = self.time_distribution.sample_prior(self.batch_size +
457+
self.num_negatives)
458+
negative_idx = reference_idx[self.batch_size:]
459+
reference_idx = reference_idx[:self.batch_size]
479460
behavior_positive_idx = self.behavior_distribution.sample_conditional(
480461
reference_idx)
481462
time_positive_idx = self.time_distribution.sample_conditional(
@@ -493,13 +474,18 @@ class FullDataLoader(ContinuousDataLoader):
493474

494475
def __post_init__(self):
495476
super().__post_init__()
496-
self.batch_size = None
477+
478+
if self.batch_size is not None:
479+
raise ValueError("Batch size cannot be set for FullDataLoader.")
480+
if self.num_negatives is not None:
481+
raise ValueError(
482+
"Number of negatives cannot be set for FullDataLoader.")
497483

498484
@property
499485
def offset(self):
500486
return self.dataset.offset
501487

502-
def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex:
488+
def get_indices(self) -> BatchIndex:
503489
"""Samples indices for reference, positive and negative examples.
504490
505491
The reference indices are all available (valid, according to the
@@ -519,8 +505,6 @@ def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex:
519505
Add the ``empirical`` vs. ``discrete`` sampling modes to this
520506
class.
521507
"""
522-
assert num_samples is None
523-
assert num_negatives is None
524508

525509
reference_idx = torch.arange(
526510
self.offset.left,
@@ -534,7 +518,6 @@ def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex:
534518
positive=positive_idx,
535519
negative=negative_idx)
536520

537-
def __iter__(self):
521+
def __iter__(self) -> Iterator[BatchIndex]:
538522
for _ in range(len(self)):
539-
index = self.get_indices(num_samples=self.batch_size)
540-
yield index
523+
yield self.get_indices()

0 commit comments

Comments
 (0)