@@ -108,6 +108,149 @@ def infonce_loss(
     return avg_loss


+def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
+                          X: Union[npt.NDArray, torch.Tensor],
+                          *y,
+                          session_id: Optional[int] = None,
+                          num_batches: int = 500) -> float:
116+ """Compute the goodness of fit score on a *single session* dataset on the model.
117+
118+ This function uses the :func:`infonce_loss` function to compute the InfoNCE loss
119+ for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
120+ to derive the goodness of fit from the InfoNCE loss.
121+
122+ Args:
123+ cebra_model: The model to use to compute the InfoNCE loss on the samples.
124+ X: A 2D data matrix, corresponding to a *single session* recording.
125+ y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one
126+ discrete index passed as a 1D array. Each index has to match the length of ``X``.
127+ session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
128+ for multisession, set to ``None`` for single session.
129+ num_batches: The number of iterations to consider to evaluate the model on the new data.
130+ Higher values will give a more accurate estimate. Set it to at least 500 iterations.
131+
132+ Returns:
133+ The average GoF score estimated over ``num_batches`` batches from the data distribution.
134+
135+ Related:
136+ :func:`infonce_to_goodness_of_fit`
137+
138+ Example:
139+
140+ >>> import cebra
141+ >>> import numpy as np
142+ >>> neural_data = np.random.uniform(0, 1, (1000, 20))
143+ >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
144+ >>> cebra_model.fit(neural_data)
145+ CEBRA(batch_size=512, max_iterations=10)
146+ >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
147+ """
+    loss = infonce_loss(cebra_model,
+                        X,
+                        *y,
+                        session_id=session_id,
+                        num_batches=num_batches,
+                        correct_by_batchsize=False)
+    return infonce_to_goodness_of_fit(loss, cebra_model)
+
+
+def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
158+ """Return the history of the goodness of fit score.
159+
160+ Args:
161+ model: A trained CEBRA model.
162+
163+ Returns:
164+ A numpy array containing the goodness of fit values, measured in bits.
165+
166+ Related:
167+ :func:`infonce_to_goodness_of_fit`
168+
169+ Example:
170+
171+ >>> import cebra
172+ >>> import numpy as np
173+ >>> neural_data = np.random.uniform(0, 1, (1000, 20))
174+ >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
175+ >>> cebra_model.fit(neural_data)
176+ CEBRA(batch_size=512, max_iterations=10)
177+ >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
178+ """
+    infonce = np.array(model.state_dict_["log"]["total"])
+    return infonce_to_goodness_of_fit(infonce, model)
+
+
+def infonce_to_goodness_of_fit(
+        infonce: Union[float, np.ndarray],
+        model: Optional[cebra_sklearn_cebra.CEBRA] = None,
+        batch_size: Optional[int] = None,
+        num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
188+ """Given a trained CEBRA model, return goodness of fit metric.
189+
190+ The goodness of fit ranges from 0 (lowest meaningful value)
191+ to a positive number with the unit "bits", the higher the
192+ better.
193+
194+ Values lower than 0 bits are possible, but these only occur
195+ due to numerical effects. A perfectly collapsed embedding
196+ (e.g., because the data cannot be fit with the provided
197+ auxiliary variables) will have a goodness of fit of 0.
198+
+    The conversion between the generalized InfoNCE metric that
+    CEBRA is trained with and the goodness of fit computed with this
+    function is
+
+    .. math::
+
+        S = \\log N - \\text{InfoNCE},
+
+    where :math:`N` is the effective batch size used for training,
+    i.e., ``batch_size * num_sessions``.
+
+    To use this function, either provide a trained CEBRA model or the
+    batch size and number of sessions.
+
+    Args:
+        infonce: The InfoNCE loss, either a single value or an iterable of values.
+        model: The trained CEBRA model.
+        batch_size: The batch size used to train the model.
+        num_sessions: The number of sessions used to train the model.
+
+    Returns:
+        The goodness of fit, measured in bits: a single value if ``infonce``
+        is a single value, otherwise a numpy array.
+
+    Raises:
+        RuntimeError: If the provided model is not fit to data.
+        ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided.
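+
+    Example:
+
+        An illustrative conversion, using assumed values rather than a real
+        training run: with ``batch_size=512`` and a single session, the
+        chance level is :math:`\\log 512 \\approx 6.24` nats, so an InfoNCE
+        loss of 5.0 corresponds to roughly 1.79 bits.
+
+        >>> import cebra
+        >>> gof = cebra.sklearn.metrics.infonce_to_goodness_of_fit(
+        ...     5.0, batch_size=512, num_sessions=1)
+        >>> print(f"{gof:.2f}")
+        1.79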
222+ """
+    if model is not None:
+        if batch_size is not None or num_sessions is not None:
+            raise ValueError(
+                "batch_size and num_sessions should not be provided if model is provided."
+            )
+        if not hasattr(model, "state_dict_"):
+            raise RuntimeError("Fit the CEBRA model first.")
+        if model.batch_size is None:
+            raise ValueError(
+                "Computing the goodness of fit is not yet supported for "
+                "models trained on the full dataset (batch_size=None).")
+        batch_size = model.batch_size
+        num_sessions = model.num_sessions_
+        if num_sessions is None:
+            num_sessions = 1
+    else:
+        if batch_size is None or num_sessions is None:
+            raise ValueError(
+                f"batch_size ({batch_size}) and num_sessions ({num_sessions}) "
+                f"should be provided if model is not provided.")
+
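+    # A perfectly collapsed embedding reaches the chance-level loss of
+    # log(batch_size * num_sessions) nats; the gap between chance level and
+    # the observed loss is converted from nats to bits via log2(e).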
+    nats_to_bits = np.log2(np.e)
+    chance_level = np.log(batch_size * num_sessions)
+    return (chance_level - infonce) * nats_to_bits
+
+
 def _consistency_scores(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     datasets: List[Union[int, str]],