@@ -138,7 +138,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
138138 >>> import cebra
139139 >>> import numpy as np
140140 >>> neural_data = np.random.uniform(0, 1, (1000, 20))
141- >>> cebra_model = cebra.CEBRA(max_iterations=10)
141+ >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
142142 >>> cebra_model.fit(neural_data)
143143 CEBRA(max_iterations=10)
144144 >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
@@ -169,7 +169,7 @@ def goodness_of_fit_history(model):
169169 >>> import cebra
170170 >>> import numpy as np
171171 >>> neural_data = np.random.uniform(0, 1, (1000, 20))
172- >>> cebra_model = cebra.CEBRA(max_iterations=10)
172+ >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
173173 >>> cebra_model.fit(neural_data)
174174 CEBRA(max_iterations=10)
175175 >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
@@ -210,6 +210,11 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
210210 """
211211     if not hasattr(model, "state_dict_"):
212212         raise RuntimeError("Fit the CEBRA model first.")
213+     if model.batch_size is None:
214+         raise ValueError(
215+             "Computing the goodness of fit is not yet supported for "
216+             "models trained on the full dataset (batchsize = None). "
217+         )
213218
214219     nats_to_bits = np.log2(np.e)
215220     num_sessions = model.num_sessions_
0 commit comments