Skip to content

Commit 923675b

Browse files
committed
Fix examples
1 parent f7a7042 commit 923675b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
141141
>>> cebra_model = cebra.CEBRA(max_iterations=10)
142142
>>> cebra_model.fit(neural_data)
143143
CEBRA(max_iterations=10)
144-
>>> gof = cebra.goodness_of_fit_score(cebra_model, neural_data)
144+
>>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
145145
"""
146146
loss = infonce_loss(cebra_model,
147147
X,
@@ -172,7 +172,7 @@ def goodness_of_fit_history(model):
172172
>>> cebra_model = cebra.CEBRA(max_iterations=10)
173173
>>> cebra_model.fit(neural_data)
174174
CEBRA(max_iterations=10)
175-
>>> gof_history = cebra.goodness_of_fit_history(cebra_model)
175+
>>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
176176
"""
177177
infonce = np.array(model.state_dict_["log"]["total"])
178178
return infonce_to_goodness_of_fit(infonce, model)
@@ -215,7 +215,7 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
215215
num_sessions = model.num_sessions_
216216
if num_sessions is None:
217217
num_sessions = 1
218-
chance_level = np.log(model.batch_size * (model.num_sessions_ or 1))
218+
chance_level = np.log(model.batch_size * num_sessions)
219219
return (chance_level - infonce) * nats_to_bits
220220

221221

0 commit comments

Comments
 (0)