Skip to content

Commit f2af3b6

Browse files
committed
Update GoF computation
1 parent e259e45 commit f2af3b6

File tree

3 files changed

+28
-14
lines changed

3 files changed

+28
-14
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,13 @@ def num_sessions(self) -> Optional[int]:
596596
"""
597597
return self.num_sessions_
598598

599+
@property
600+
def num_negatives_(self) -> int:
601+
"""The number of negative examples."""
602+
if self.num_negatives is None:
603+
return self.batch_size
604+
return self.num_negatives
605+
599606
@property
600607
def state_dict_(self) -> dict:
601608
return self.solver_.state_dict()

cebra/integrations/sklearn/metrics.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ def infonce_loss(
100100
solver.to(cebra_model.device_)
101101
avg_loss = solver.validation(loader=loader, session_id=session_id)
102102
if correct_by_batchsize:
103-
if cebra_model.batch_size is None:
103+
if cebra_model.num_negatives_ is None:
104104
raise ValueError(
105105
"Batch size is None, please provide a model with a batch size to correct the InfoNCE."
106106
)
107107
else:
108-
avg_loss = avg_loss - np.log(cebra_model.batch_size)
108+
avg_loss = avg_loss - np.log(cebra_model.num_negatives_)
109109
return avg_loss
110110

111111

@@ -211,7 +211,7 @@ def infonce_to_goodness_of_fit(
211211
Args:
212212
infonce: The InfoNCE loss, either a single value or an iterable of values.
213213
model: The trained CEBRA model.
214-
batch_size: The batch size used to train the model.
214+
batch_size: The batch size (or number of negatives, if different from the batch size) used to train the model.
215215
num_sessions: The number of sessions used to train the model.
216216
217217
Returns:
@@ -228,19 +228,15 @@ def infonce_to_goodness_of_fit(
228228
)
229229
if not hasattr(model, "state_dict_"):
230230
raise RuntimeError("Fit the CEBRA model first.")
231-
if model.batch_size is None:
231+
if model.num_negatives_ is None:
232232
raise ValueError(
233233
"Computing the goodness of fit is not yet supported for "
234234
"models trained on the full dataset (batchsize = None). ")
235-
batch_size = model.batch_size
235+
batch_size = model.num_negatives_
236236
num_sessions = model.num_sessions_
237237
if num_sessions is None:
238238
num_sessions = 1
239239

240-
if model.batch_size is None:
241-
raise ValueError(
242-
"Computing the goodness of fit is not yet supported for "
243-
"models trained on the full dataset (batchsize = None). ")
244240
else:
245241
if batch_size is None or num_sessions is None:
246242
raise ValueError(

tests/test_sklearn_metrics.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,14 +482,22 @@ def _fit_and_get_history(X, y):
482482

483483

484484
@pytest.mark.parametrize("seed", [42, 24, 10])
485-
def test_infonce_to_goodness_of_fit(seed):
485+
@pytest.mark.parametrize("batch_size", [100, 200])
486+
@pytest.mark.parametrize("num_negatives", [None, 100, 200])
487+
def test_infonce_to_goodness_of_fit(seed, batch_size, num_negatives):
486488
"""Test the conversion from InfoNCE loss to goodness of fit metric."""
489+
nats_to_bits = np.log2(np.e)
490+
487491
# Test with model
488492
cebra_model = cebra_sklearn_cebra.CEBRA(
489493
model_architecture="offset10-model",
490494
max_iterations=5,
491-
batch_size=128,
495+
batch_size=batch_size,
496+
num_negatives=num_negatives,
492497
)
498+
if num_negatives is None:
499+
num_negatives = batch_size
500+
493501
generator = torch.Generator().manual_seed(seed)
494502
X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
495503
cebra_model.fit(X)
@@ -498,19 +506,22 @@ def test_infonce_to_goodness_of_fit(seed):
498506
gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
499507
model=cebra_model)
500508
assert isinstance(gof, float)
509+
assert np.isclose(gof, (np.log(num_negatives) - 1.0) * nats_to_bits)
501510

502511
# Test array of values
503512
infonce_values = np.array([1.0, 2.0, 3.0])
504513
gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
505514
infonce_values, model=cebra_model)
506515
assert isinstance(gof_array, np.ndarray)
507516
assert gof_array.shape == infonce_values.shape
517+
assert np.allclose(gof_array,
518+
(np.log(num_negatives) - infonce_values) * nats_to_bits)
508519

509520
# Test with explicit batch_size and num_sessions
510-
gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
511-
batch_size=128,
512-
num_sessions=1)
521+
gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
522+
1.0, batch_size=batch_size, num_sessions=1)
513523
assert isinstance(gof, float)
524+
assert np.isclose(gof, (np.log(batch_size) - 1.0) * nats_to_bits)
514525

515526
# Test error cases
516527
with pytest.raises(ValueError, match="batch_size.*should not be provided"):

0 commit comments

Comments (0)