
Commit 297ee92

Add goodness of fit metric
1 parent 9e14790 commit 297ee92

2 files changed, +113 -0 lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 62 additions & 0 deletions
@@ -108,6 +108,68 @@ def infonce_loss(
     return avg_loss


+def goodness_of_fit(model: cebra_sklearn_cebra.CEBRA) -> List[float]:
+    """Evaluate the goodness of fit (bits) for a given model.
+
+    This function calculates the goodness of fit for the provided model
+    using the batch size the model was trained with. The goodness of fit
+    is computed offline; it normalizes the loss with respect to the batch
+    size, so that models trained with different batch sizes or different
+    implementations can be compared.
+
+    Args:
+        model: The fitted `cebra_sklearn_cebra.CEBRA` model to evaluate.
+
+    Returns:
+        A list of float values representing the goodness of fit for the model.
+    """
+
+    if isinstance(model, cebra_sklearn_cebra.CEBRA):
+        if model.batch_size is None:
+            raise NotImplementedError(
+                "Batch size is None, please provide a model with a batch size to compute the goodness of fit."
+            )
+        if model.solver_name_ == 'single-session':
+            gof = _goodness_of_fit(loss=model.state_dict_["loss"],
+                                   batch_size=model.batch_size)
+        elif model.solver_name_ == 'multi-session':
+            # For the multisession implementation, the batch size is multiplied by the
+            # number of datasets to get the correct comparison.
+            gof = _goodness_of_fit(loss=model.state_dict_["loss"],
+                                   batch_size=model.batch_size *
+                                   model.num_sessions_)
+        else:
+            raise NotImplementedError(f"Invalid solver: {model.solver_name_}.")
+    elif isinstance(model, list):
+        raise ValueError(
+            f"Model should correspond to a single CEBRA model, "
+            f"got {type(model)}, containing {len(model)} elements.")
+    else:
+        raise ValueError(f"Provide CEBRA model, got {type(model)}.")
+    return gof
+
+
+def _goodness_of_fit(loss: List[float], batch_size: int) -> List[float]:
+    """Compute the goodness of fit (bits) offline from a provided loss.
+
+    This is a way to normalize with respect to the batch size in order to
+    compare models with different batch sizes or different implementations.
+
+    Args:
+        loss: A list of size `max_iterations`, corresponding to the loss across training.
+        batch_size: Batch size used to train the model. For the multisession
+            implementation, multiply the batch size by the number of datasets
+            to get the correct comparison.
+
+    Returns:
+        A list of floats corresponding to the goodness of fit for the provided loss and batch size.
+    """
+    log_batch_size = np.log(batch_size)
+    return [(1 / np.log(2)) * (log_batch_size - lb) for lb in loss]
+
+
 def _consistency_scores(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     datasets: List[Union[int, str]],
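
For context on what the new helper computes: with batch size N, the InfoNCE loss of a chance-level model is log N nats (uniform scores over the N candidates in the batch), so `_goodness_of_fit` rescales each recorded loss value to (log N - loss) / log 2, i.e. bits above chance; 0 bits means chance level and higher means a better fit. Below is a minimal usage sketch of the new public function on a fitted model; the synthetic data and hyperparameters are illustrative only and not part of this commit:

    import numpy as np
    import cebra

    # Illustrative random data: 1000 samples, 50 features.
    X = np.random.uniform(0, 1, (1000, 50))

    model = cebra.CEBRA(model_architecture="offset10-model",
                        max_iterations=10,
                        batch_size=512)
    model.fit(X)

    # One value (in bits) per recorded loss entry; 0 bits is chance level,
    # higher values indicate a better fit.
    gof = cebra.sklearn.metrics.goodness_of_fit(model)
    print(gof[-1])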

tests/test_sklearn_metrics.py

Lines changed: 51 additions & 0 deletions
@@ -223,6 +223,57 @@ def test_sklearn_infonce_loss():
     )


+def test_sklearn_goodness_of_fit():
+    max_loss_iterations = 2
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset10-model",
+        max_iterations=5,
+        batch_size=128,
+    )
+
+    # Example data
+    X = torch.tensor(np.random.uniform(0, 1, (1000, 50)))
+    y_c1 = torch.tensor(np.random.uniform(0, 1, (1000, 5)))
+
+    X2 = torch.tensor(np.random.uniform(0, 1, (500, 20)))
+    y2_c1 = torch.tensor(np.random.uniform(0, 1, (500, 5)))
+
+    # Single session
+    cebra_model.fit(X, y_c1)
+
+    gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model)
+    assert isinstance(gof, list)
+    _gof = cebra.sklearn.metrics._goodness_of_fit(
+        cebra_model.state_dict_["loss"], batch_size=128)
+    assert isinstance(_gof, list)
+    assert gof == _gof
+
+    # Multisession
+    cebra_model.fit([X, X2], [y_c1, y2_c1])
+
+    gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model)
+    assert isinstance(gof, list)
+    _gof = cebra.sklearn.metrics._goodness_of_fit(
+        cebra_model.state_dict_["loss"], batch_size=128 * 2)
+    assert isinstance(_gof, list)
+    assert gof == _gof
+
+    # Multiple models passed
+    with pytest.raises(ValueError, match="single.*model"):
+        _ = cebra.sklearn.metrics.goodness_of_fit([cebra_model, cebra_model])
+
+    # No batch size
+    cebra_model_no_bs = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset10-model",
+        max_iterations=max_loss_iterations,
+        batch_size=None,
+    )
+
+    cebra_model_no_bs.fit(X)
+    with pytest.raises(NotImplementedError, match="Batch.*size"):
+        gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model_no_bs)
+
+
 def test_sklearn_datasets_consistency():
     # Example data
     np.random.seed(42)
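
A quick sanity check on the normalization these tests rely on (a standalone sketch, not part of the commit): a loss exactly at the chance level log(batch_size) maps to 0 bits, and each nat of improvement below chance adds 1 / log(2) ≈ 1.443 bits:

    import numpy as np

    batch_size = 128
    chance = np.log(batch_size)  # chance-level InfoNCE loss, in nats

    def gof_bits(loss_value, n):
        # Same mapping as _goodness_of_fit, applied to a single loss value.
        return (np.log(n) - loss_value) / np.log(2)

    assert np.isclose(gof_bits(chance, batch_size), 0.0)
    assert np.isclose(gof_bits(chance - 1.0, batch_size), 1 / np.log(2))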
