
Commit f7a7042

add tests and improve implementation

1 parent 2372c8b

2 files changed: +116 −18 lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 52 additions & 18 deletions
@@ -108,16 +108,15 @@ def infonce_loss(
     return avg_loss


-def goodness_of_fit_score(
-    cebra_model: cebra_sklearn_cebra.CEBRA,
-    X: Union[npt.NDArray, torch.Tensor],
-    *y,
-    session_id: Optional[int] = None,
-    num_batches: int = 500,
-    correct_by_batchsize: bool = False,
-) -> float:
+def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
+                          X: Union[npt.NDArray, torch.Tensor],
+                          *y,
+                          session_id: Optional[int] = None,
+                          num_batches: int = 500) -> float:
     """Compute the InfoNCE loss on a *single session* dataset on the model.

+    This function uses the :func:`infonce_loss` function to compute the InfoNCE loss.
+
     Args:
         cebra_model: The model to use to compute the InfoNCE loss on the samples.
         X: A 2D data matrix, corresponding to a *single session* recording.
@@ -127,23 +126,60 @@ def goodness_of_fit_score(
            for multisession, set to ``None`` for single session.
         num_batches: The number of iterations to consider to evaluate the model on the new data.
            Higher values will give a more accurate estimate. Set it to at least 500 iterations.
+
+    Returns:
+        The average GoF score estimated over ``num_batches`` batches from the data distribution.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(max_iterations=10)
+        >>> gof = cebra.goodness_of_fit_score(cebra_model, neural_data)
     """
-    loss = infonce_loss(cebra_model=cebra_model,
-                        X=X,
+    loss = infonce_loss(cebra_model,
+                        X,
                         *y,
                         session_id=session_id,
-                        num_batches=500,
+                        num_batches=num_batches,
                         correct_by_batchsize=False)
     return infonce_to_goodness_of_fit(loss, cebra_model)


-def goodness_of_fit_score(model):
+def goodness_of_fit_history(model):
+    """Return the history of the goodness of fit score.
+
+    Args:
+        model: A trained CEBRA model.
+
+    Returns:
+        A numpy array containing the goodness of fit values, measured in bits.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(max_iterations=10)
+        >>> gof_history = cebra.goodness_of_fit_history(cebra_model)
+    """
     infonce = np.array(model.state_dict_["log"]["total"])
     return infonce_to_goodness_of_fit(infonce, model)


 def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
-                               model: cebra.CEBRA) -> np.ndarray:
+                               model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
     """Given a trained CEBRA model, return goodness of fit metric

     The goodness of fit ranges from 0 (lowest meaningful value)
@@ -161,18 +197,16 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],

     .. math::

-        S = \log N - \text{InfoNCE}
+        S = \\log N - \\text{InfoNCE}

     Args:
         model: The trained CEBRA model

     Returns:
-        Numpy array containing the goodness of fit
-        values, measured in bits
+        Numpy array containing the goodness of fit values, measured in bits

     Raises:
-        ``RuntimeError``, if provided model is not
-        fit to data.
+        ``RuntimeError``, if provided model is not fit to data.
     """
     if not hasattr(model, "state_dict_"):
         raise RuntimeError("Fit the CEBRA model first.")
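
For readers skimming the diff, the conversion this commit documents is compact enough to sketch. The snippet below is a minimal, assumption-based reimplementation of the docstring's formula S = log N − InfoNCE, with N taken from the estimator's batch size and a nats-to-bits rescaling by log2(e) to match the "measured in bits" note; it is an illustration, not the committed code, and the function name is hypothetical.

    import numpy as np

    def gof_bits_from_infonce(infonce_nats, batch_size):
        """Sketch of the documented conversion: S = log(N) - InfoNCE, in bits."""
        # Chance level: a model with no structure picks the positive among
        # N candidates at random, so its InfoNCE loss sits at log(N) nats.
        chance_level = np.log(batch_size)
        # Subtract the measured loss and rescale nats -> bits (1 nat = log2(e) bits).
        return (chance_level - np.asarray(infonce_nats)) * np.log2(np.e)

    # A chance-level loss maps to 0 bits; a (hypothetical) loss of 0 nats
    # maps to the maximum of log2(N) bits.
    print(gof_bits_from_infonce(np.log(512), 512))  # 0.0
    print(gof_bits_from_infonce(0.0, 512))          # 9.0

Under this reading, 0 bits means "no better than chance" and larger values mean the model discriminates positives from negatives increasingly well, which is what the tests below exercise.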

tests/test_sklearn_metrics.py

Lines changed: 64 additions & 0 deletions
@@ -383,3 +383,67 @@ def test_sklearn_runs_consistency():
     with pytest.raises(ValueError, match="Invalid.*embeddings"):
         _, _, _ = cebra_sklearn_metrics.consistency_score(
             invalid_embeddings_runs, between="runs")
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_score(seed):
+    """
+    Ensure that the GoF score is close to 0 for a model fit on random data.
+    """
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset1-model",
+        max_iterations=5,
+        batch_size=512,
+    )
+    X = torch.tensor(np.random.uniform(0, 1, (5000, 50)))
+    y = torch.tensor(np.random.uniform(0, 1, (5000, 5)))
+    cebra_model.fit(X, y)
+    score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
+                                                        X,
+                                                        y,
+                                                        session_id=0,
+                                                        num_batches=500)
+    assert isinstance(score, float)
+    assert np.isclose(score, 0, atol=0.01)
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_history(seed):
+    """
+    Ensure that the GoF score is higher for a model fit on data with underlying
+    structure than for a model fit on random data.
+    """
+
+    # Generate data
+    generator = torch.Generator().manual_seed(seed)
+    X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
+    y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator)
+    linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator)
+    y_linear = X @ linear_map
+
+    def _fit_and_get_history(X, y):
+        cebra_model = cebra_sklearn_cebra.CEBRA(
+            model_architecture="offset1-model",
+            max_iterations=150,
+            batch_size=512,
+            device="cpu")
+        cebra_model.fit(X, y)
+        history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model)
+        # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
+        # due to numerical issues.
+        return history[5:]
+
+    history_random = _fit_and_get_history(X, y_random)
+    history_linear = _fit_and_get_history(X, y_linear)
+
+    assert isinstance(history_random, np.ndarray)
+    assert history_random.shape[0] > 0
+    # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
+    # due to numerical issues.
+    history_random_non_negative = history_random[history_random >= 0]
+    np.testing.assert_allclose(history_random_non_negative, 0, atol=0.05)
+
+    assert isinstance(history_linear, np.ndarray)
+    assert history_linear.shape[0] > 0
+
+    assert np.all(history_linear[-20:] > history_random[-20:])
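
As a back-of-envelope check of the tolerances used above (an inference from the docstring's formula, not something the commit states): with batch_size=512 the chance-level InfoNCE is log(512) ≈ 6.24 nats, so a model fit on pure noise should hover at chance and score roughly 0 bits, which is what the atol=0.01 and atol=0.05 assertions probe, while the linearly structured labels let the loss drop below chance and push the history positive.

    import numpy as np

    # Hypothetical sanity check of the tolerances in the tests above,
    # assuming GoF bits = (log(N) - InfoNCE) * log2(e).
    batch_size = 512
    chance_nats = np.log(batch_size)   # ≈ 6.238 nats: loss of a chance-level model
    # atol=0.01 bits of score corresponds to ~0.007 nats of loss slack around chance:
    print(0.01 / np.log2(np.e))        # ≈ 0.0069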
