Skip to content

Commit 6c2d559

Browse files
committed
add sklearn implementation
1 parent 07212f2 commit 6c2d559

File tree

3 files changed, with 24 additions and 2 deletions.

cebra/data/multi_session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,10 @@ def __post_init__(self):
236236
self.sampler = cebra.distributions.UnifiedSampler(
237237
self.dataset, self.time_offset)
238238

239-
if self.batch_size < 2:
239+
if self.batch_size is not None and self.batch_size < 2:
240240
raise ValueError("UnifiedLoader does not support batch_size < 2.")
241241

242-
if self.num_negatives < 2:
242+
if self.num_negatives is not None and self.num_negatives < 2:
243243
raise ValueError(
244244
"UnifiedLoader does not support num_negatives < 2.")
245245

cebra/integrations/sklearn/cebra.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,9 @@ class CEBRA(TransformerMixin, BaseEstimator):
501501
A Tuple of masking types and their corresponding required masking values. The keys are the
502502
names of the Mask instances and formatting should be ``((key, value), (key, value))``.
503503
|Default:| ``None``.
504+
num_negatives (Optional[int]):
505+
The number of negative samples to use for training. If ``None``, the number of negative samples
506+
will be set to the batch size. |Default:| ``None``.
504507
505508
Example:
506509
@@ -576,6 +579,7 @@ def __init__(
576579
),
577580
masking_kwargs: Tuple[Tuple[str, Union[float, List[float],
578581
Tuple[float, ...]]], ...] = None,
582+
num_negatives: int = None,
579583
):
580584
self.__dict__.update(locals())
581585

@@ -728,6 +732,7 @@ def _prepare_loader(self, dataset: cebra.data.Dataset, max_iterations: int,
728732
dataset=dataset,
729733
batch_size=self.batch_size,
730734
num_steps=max_iterations,
735+
num_negatives=self.num_negatives,
731736
),
732737
extra_kwargs=dict(
733738
time_offsets=self.time_offsets,

tests/test_sklearn.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,3 +1544,20 @@ def test_last_incomplete_batch_smaller_than_offset():
15441544
model.fit(train.neural, train.continuous)
15451545

15461546
_ = model.transform(train.neural, batch_size=300)
1547+
1548+
1549+
@pytest.mark.parametrize("batch_size,num_negatives", [
1550+
(None, None),
1551+
(100, None),
1552+
(100, 100),
1553+
])
1554+
def test_num_negatives(batch_size, num_negatives):
1555+
train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100),
1556+
continuous=np.random.rand(20111, 2))
1557+
1558+
model = cebra.CEBRA(max_iterations=2,
1559+
batch_size=batch_size,
1560+
num_negatives=num_negatives,
1561+
device="cpu")
1562+
model.fit(train.neural, train.continuous)
1563+
_ = model.transform(train.neural)

0 commit comments

Comments
 (0)