@@ -108,16 +108,15 @@ def infonce_loss(
     return avg_loss
 
 
-def goodness_of_fit_score(
-    cebra_model: cebra_sklearn_cebra.CEBRA,
-    X: Union[npt.NDArray, torch.Tensor],
-    *y,
-    session_id: Optional[int] = None,
-    num_batches: int = 500,
-    correct_by_batchsize: bool = False,
-) -> float:
+def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
+                          X: Union[npt.NDArray, torch.Tensor],
+                          *y,
+                          session_id: Optional[int] = None,
+                          num_batches: int = 500) -> float:
     """Compute the InfoNCE loss on a *single session* dataset on the model.
 
+    This function uses the :func:`infonce_loss` function to compute the InfoNCE loss.
+
     Args:
         cebra_model: The model to use to compute the InfoNCE loss on the samples.
         X: A 2D data matrix, corresponding to a *single session* recording.
@@ -127,23 +126,60 @@ def goodness_of_fit_score(
             for multisession, set to ``None`` for single session.
         num_batches: The number of iterations to consider to evaluate the model on the new data.
             Higher values will give a more accurate estimate. Set it to at least 500 iterations.
+
+    Returns:
+        The average GoF score estimated over ``num_batches`` batches from the data distribution.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(max_iterations=10)
+        >>> gof = cebra.goodness_of_fit_score(cebra_model, neural_data)
     """
-    loss = infonce_loss(cebra_model=cebra_model,
-                        X=X,
+    loss = infonce_loss(cebra_model,
+                        X,
                         *y,
                         session_id=session_id,
-                        num_batches=500,
+                        num_batches=num_batches,
                         correct_by_batchsize=False)
     return infonce_to_goodness_of_fit(loss, cebra_model)
 
 
-def goodness_of_fit_score(model):
+def goodness_of_fit_history(model):
+    """Return the history of the goodness of fit score.
+
+    Args:
+        model: A trained CEBRA model.
+
+    Returns:
+        A numpy array containing the goodness of fit values, measured in bits.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(max_iterations=10)
+        >>> gof_history = cebra.goodness_of_fit_history(cebra_model)
+    """
     infonce = np.array(model.state_dict_["log"]["total"])
     return infonce_to_goodness_of_fit(infonce, model)
 
 
 def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
-                               model: cebra.CEBRA) -> np.ndarray:
+                               model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
     """Given a trained CEBRA model, return goodness of fit metric
 
     The goodness of fit ranges from 0 (lowest meaningful value)
@@ -161,18 +197,16 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
 
     .. math::
 
-        S = \log N - \text{InfoNCE}
+        S = \\log N - \\text{InfoNCE}
 
     Args:
         model: The trained CEBRA model
 
     Returns:
-        Numpy array containing the goodness of fit
-        values, measured in bits
+        Numpy array containing the goodness of fit values, measured in bits
 
     Raises:
-        ``RuntimeError``, if provided model is not
-        fit to data.
+        ``RuntimeError``, if provided model is not fit to data.
     """
     if not hasattr(model, "state_dict_"):
         raise RuntimeError("Fit the CEBRA model first.")
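
For reference, here is a minimal, self-contained sketch of the conversion the docstring above describes: the goodness-of-fit score is the gap between the chance-level loss, log N for batch size N, and the measured InfoNCE loss, scaled from nats to bits by log2(e) to match the "measured in bits" wording. The helper name `infonce_to_bits` and the explicit `batch_size` argument are illustrative assumptions, not the CEBRA API; the library's `infonce_to_goodness_of_fit` takes the fitted model instead.

# Sketch of S = \log N - InfoNCE, expressed in bits.
# `infonce_to_bits` and `batch_size` are hypothetical names for
# illustration only; they are not part of the CEBRA API.
import numpy as np

def infonce_to_bits(infonce, batch_size):
    """Map InfoNCE loss values (in nats) to the goodness-of-fit score S, in bits."""
    infonce = np.asarray(infonce, dtype=float)
    chance_level = np.log(batch_size)  # loss of an uninformative model, in nats
    nats_to_bits = np.log2(np.e)       # 1 nat == log2(e) bits
    return (chance_level - infonce) * nats_to_bits

# A model stuck at chance level scores 0 bits; lower loss scores higher.
print(infonce_to_bits(np.log(512), batch_size=512))        # 0.0
print(infonce_to_bits(np.log(512) - 1.0, batch_size=512))  # ~1.44 bits

Because the conversion is elementwise, it applies equally to a single score and to a whole training log, which is why `goodness_of_fit_history` in the diff can pass the full ``state_dict_["log"]["total"]`` array through `infonce_to_goodness_of_fit` in one call.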