@@ -108,6 +108,68 @@ def infonce_loss(
    return avg_loss


def goodness_of_fit(model: cebra_sklearn_cebra.CEBRA) -> List[float]:
112+ """Evaluate the goodness of fit (bits) for a given model.
113+
114+ This function calculates the goodness of fit for the provided model
115+ using the specified batch size. The goodness of fit is computed offline
116+ it is a way to normalize wrt batch size to compare models with
117+ different batch sizes or different implementations.
118+
119+ Args:
120+ model: The model to evaluate. This can be an instance of either
121+ `cebra_sklearn_cebra.CEBRA` or `cebra_solver.Solver`.
122+ batch_size: Batch size used to train the model.
123+
124+ Returns:
125+ A list of float values representing the goodness of fit for the model.
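
    Example:
        A minimal sketch of the intended usage, assuming the public
        ``cebra.CEBRA`` estimator; the exact values depend on training.

        >>> import numpy as np
        >>> import cebra
        >>> X = np.random.uniform(0, 1, (1000, 30))
        >>> model = cebra.CEBRA(max_iterations=10, batch_size=512)
        >>> model = model.fit(X)  # assignment suppresses the repr output
        >>> gof = goodness_of_fit(model)  # one value per recorded loss step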
    """

    if isinstance(model, cebra_sklearn_cebra.CEBRA):
        if model.batch_size is None:
            raise NotImplementedError(
                "Batch size is None, please provide a model trained with a "
                "batch size to compute the goodness of fit.")
        if model.solver_name_ == 'single-session':
            gof = _goodness_of_fit(loss=model.state_dict_["loss"],
                                   batch_size=model.batch_size)
        elif model.solver_name_ == 'multi-session':
            # For the multi-session implementation, the batch size is
            # multiplied by the number of sessions to get the correct
            # comparison.
            gof = _goodness_of_fit(loss=model.state_dict_["loss"],
                                   batch_size=model.batch_size *
                                   model.num_sessions_)
        else:
            raise NotImplementedError(f"Invalid solver: {model.solver_name_}.")
    elif isinstance(model, list):
        raise ValueError(
            f"Model should correspond to a single CEBRA model, "
            f"got {type(model)}, containing {len(model)} elements.")
    else:
        raise ValueError(f"Provide a CEBRA model, got {type(model)}.")
    return gof


def _goodness_of_fit(loss: List[float], batch_size: int) -> List[float]:
    """Compute the goodness of fit (in bits) offline from a provided loss.

    Normalizing the loss with respect to the batch size makes it possible to
    compare models trained with different batch sizes or different
    implementations.

    Args:
        loss: A list of length `max_iterations`, the loss recorded across training.
        batch_size: Batch size used to train the model. For the multi-session
            implementation, the batch size needs to be multiplied by the number
            of sessions to get the correct comparison.

    Returns:
        A list of floats, the goodness of fit for the provided loss and batch size.
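
    Example:
        An illustrative check of the normalization (independent of any
        trained model): a loss equal to ``log(batch_size)`` corresponds
        to chance level and maps to 0 bits.

        >>> import numpy as np
        >>> float(_goodness_of_fit([np.log(512)], batch_size=512)[0])
        0.0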
    """
    # Chance level for InfoNCE with N samples per batch is log(N) nats;
    # dividing the difference by log(2) converts it from nats to bits.
    log_batch_size = np.log(batch_size)
    return [(1 / np.log(2)) * (log_batch_size - lb) for lb in loss]


def _consistency_scores(
    embeddings: List[Union[npt.NDArray, torch.Tensor]],
    datasets: List[Union[int, str]],