@@ -108,6 +108,83 @@ def infonce_loss(
     return avg_loss
 
 
+def goodness_of_fit_score(
+    cebra_model: cebra_sklearn_cebra.CEBRA,
+    X: Union[npt.NDArray, torch.Tensor],
+    *y,
+    session_id: Optional[int] = None,
+    num_batches: int = 500,
+    correct_by_batchsize: bool = False,
+) -> float:
119+ """Compute the InfoNCE loss on a *single session* dataset on the model.
+
+    Args:
+        cebra_model: The model to use to compute the InfoNCE loss on the samples.
+        X: A 2D data matrix, corresponding to a *single session* recording.
+        y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one
+            discrete index passed as a 1D array. Each index has to match the length of ``X``.
+        session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
+            for multisession; set to ``None`` for single session.
+        num_batches: The number of iterations used to evaluate the model on the new data.
+            Higher values give a more accurate estimate; use at least 500 iterations.
+        correct_by_batchsize: Whether to normalize the loss by the batch size; forwarded
+            to :py:func:`infonce_loss`.
+
+    Returns:
+        The goodness of fit score, in bits.
+    """
+    loss = infonce_loss(cebra_model,
+                        X,
+                        *y,
+                        session_id=session_id,
+                        num_batches=num_batches,
+                        correct_by_batchsize=correct_by_batchsize)
+    return infonce_to_goodness_of_fit(loss, cebra_model)
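+
+# A minimal usage sketch (``X_train`` and ``y_train`` are hypothetical arrays,
+# not part of this change):
+#
+#     model = cebra_sklearn_cebra.CEBRA(batch_size=512, max_iterations=1000)
+#     model.fit(X_train, y_train)
+#     gof = goodness_of_fit_score(model, X_train, y_train, num_batches=500)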
+
+
+def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
+    """Return the goodness of fit for each InfoNCE value logged during training."""
+    infonce = np.array(model.state_dict_["log"]["total"])
+    return infonce_to_goodness_of_fit(infonce, model)
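+
+# Hypothetical usage sketch (assumes ``matplotlib.pyplot as plt`` is imported,
+# which this change does not do):
+#
+#     history = goodness_of_fit_history(model)
+#     plt.plot(history)  # in bits; should rise and plateau as training converges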
+
+
+def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
+                               model: cebra_sklearn_cebra.CEBRA
+                               ) -> Union[float, np.ndarray]:
+    """Given a trained CEBRA model, convert InfoNCE values to the goodness of fit metric.
+
+    The goodness of fit ranges from 0 (lowest meaningful value)
+    to a positive number with the unit "bits"; the higher, the
+    better.
+
+    Values lower than 0 bits are possible, but these only occur
+    due to numerical effects. A perfectly collapsed embedding
+    (e.g., because the data cannot be fit with the provided
+    auxiliary variables) will have a goodness of fit of 0.
+
+    The conversion between the generalized InfoNCE metric that
+    CEBRA is trained with and the goodness of fit computed with this
+    function is
+
+    .. math::
+
+        S = \\log N - \\text{InfoNCE},
+
+    where :math:`N` is the batch size (multiplied by the number of
+    sessions for multisession models).
+
+    Args:
+        infonce: The InfoNCE loss, either a single value or an iterable of
+            values (e.g., the training history).
+        model: The trained CEBRA model.
+
+    Returns:
+        The goodness of fit values, measured in bits: a scalar for a
+        single InfoNCE value, otherwise a numpy array.
+
+    Raises:
+        RuntimeError: If the provided model has not been fit to data.
+        ValueError: If the model was trained with ``batch_size=None``.
+    """
+    if not hasattr(model, "state_dict_"):
+        raise RuntimeError("Fit the CEBRA model first.")
+    if model.batch_size is None:
+        raise ValueError(
+            "Computing the goodness of fit is not supported for "
+            "models trained with batch_size=None.")
+
+    nats_to_bits = np.log2(np.e)
+    num_sessions = model.num_sessions_
+    if num_sessions is None:
+        num_sessions = 1
+    # Chance level of the InfoNCE loss: log of the effective number of
+    # negative samples per batch.
+    chance_level = np.log(model.batch_size * num_sessions)
+    return (chance_level - infonce) * nats_to_bits
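+
+# Worked example of the conversion above (illustrative numbers only):
+# with batch_size N = 512, the chance level is log(512) ≈ 6.24 nats.
+# An InfoNCE loss of 5.55 nats then yields
+# (6.24 - 5.55) * log2(e) ≈ 0.69 * 1.44 ≈ 1.0 bit of goodness of fit.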
+
+
 def _consistency_scores(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     datasets: List[Union[int, str]],