
Commit 2372c8b

Started implementing improved goodness of fit implementation

1 parent 5f46c32

File tree

1 file changed: +77 −0 lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 77 additions & 0 deletions
```diff
@@ -108,6 +108,83 @@ def infonce_loss(
     return avg_loss
 
 
+def goodness_of_fit_score(
+    cebra_model: cebra_sklearn_cebra.CEBRA,
+    X: Union[npt.NDArray, torch.Tensor],
+    *y,
+    session_id: Optional[int] = None,
+    num_batches: int = 500,
+    correct_by_batchsize: bool = False,
+) -> float:
+    """Compute the goodness of fit of the model on a *single session* dataset.
+
+    Args:
+        cebra_model: The model to use to compute the InfoNCE loss on the samples.
+        X: A 2D data matrix, corresponding to a *single session* recording.
+        y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one
+            discrete index passed as a 1D array. Each index has to match the length of ``X``.
+        session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
+            for multisession; set to ``None`` for single session.
+        num_batches: The number of iterations to consider to evaluate the model on the new data.
+            Higher values will give a more accurate estimate. Set it to at least 500 iterations.
+        correct_by_batchsize: If ``True``, the InfoNCE loss is corrected by the batch size.
+
+    Returns:
+        The goodness of fit, measured in bits.
+    """
+    loss = infonce_loss(cebra_model,
+                        X,
+                        *y,
+                        session_id=session_id,
+                        num_batches=num_batches,
+                        correct_by_batchsize=correct_by_batchsize)
+    return infonce_to_goodness_of_fit(loss, cebra_model)
+
+
+def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
+    """Return the goodness of fit for every recorded training step."""
+    infonce = np.array(model.state_dict_["log"]["total"])
+    return infonce_to_goodness_of_fit(infonce, model)
+
+
+def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
+                               model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
+    r"""Given a trained CEBRA model, return the goodness of fit metric.
+
+    The goodness of fit ranges from 0 (lowest meaningful value)
+    to a positive number with the unit "bits", the higher the
+    better.
+
+    Values lower than 0 bits are possible, but these only occur
+    due to numerical effects. A perfectly collapsed embedding
+    (e.g., because the data cannot be fit with the provided
+    auxiliary variables) will have a goodness of fit of 0.
+
+    The conversion between the generalized InfoNCE metric that
+    CEBRA is trained with and the goodness of fit computed with this
+    function is
+
+    .. math::
+
+        S = \log N - \text{InfoNCE}
+
+    Args:
+        infonce: The InfoNCE loss, either a single value or an iterable of values.
+        model: The trained CEBRA model.
+
+    Returns:
+        Numpy array containing the goodness of fit
+        values, measured in bits.
+
+    Raises:
+        RuntimeError: If the provided model is not fit to data.
+    """
+    if not hasattr(model, "state_dict_"):
+        raise RuntimeError("Fit the CEBRA model first.")
+
+    nats_to_bits = np.log2(np.e)
+    num_sessions = model.num_sessions_
+    if num_sessions is None:
+        num_sessions = 1
+    chance_level = np.log(model.batch_size * num_sessions)
+    return (chance_level - infonce) * nats_to_bits
+
+
 def _consistency_scores(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     datasets: List[Union[int, str]],
```
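A minimal usage sketch for the functions added in this commit (hypothetical data and parameter values; assumes `cebra` is installed and the names above are importable from `metrics.py`):

```python
import numpy as np
from cebra import CEBRA
from cebra.integrations.sklearn.metrics import (goodness_of_fit_history,
                                                goodness_of_fit_score)

# Hypothetical single-session recording: 1000 samples, 30 neurons.
X = np.random.uniform(0, 1, (1000, 30)).astype("float32")

model = CEBRA(max_iterations=100, batch_size=512)
model.fit(X)

# Goodness of fit of the fitted model on the data, in bits.
gof = goodness_of_fit_score(model, X, num_batches=5)

# Goodness of fit across the recorded training history.
history = goodness_of_fit_history(model)
print(f"final GoF: {gof:.2f} bits; last history entry: {history[-1]:.2f} bits")
```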
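To make the nats-to-bits conversion concrete, the formula can be checked by hand; the loss value below is made up for illustration, and 512 is CEBRA's default batch size:

```python
import numpy as np

# S = log N - InfoNCE, converted from nats to bits.
batch_size = 512   # N: number of samples contrasted per batch
infonce = 5.5      # hypothetical final training loss, in nats

chance_level = np.log(batch_size)                 # log 512 ≈ 6.24 nats
gof_bits = (chance_level - infonce) * np.log2(np.e)
print(f"{gof_bits:.2f} bits")                     # ≈ 1.07 bits above chance
```

Since multiplying a natural-log quantity by log2(e) converts nats to bits, the chance level for batch size 512 is exactly log2(512) = 9 bits, and a collapsed embedding (InfoNCE at chance) scores 0 bits.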
