Skip to content

Commit 5cee743

Browse files
authored
Fix docstring tests
1 parent ad8ae60 commit 5cee743

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
140140
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
141141
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
142142
>>> cebra_model.fit(neural_data)
143-
CEBRA(max_iterations=10)
143+
CEBRA(batch_size=512, max_iterations=10)
144144
>>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
145145
"""
146146
loss = infonce_loss(cebra_model,
@@ -171,7 +171,7 @@ def goodness_of_fit_history(model):
171171
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
172172
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
173173
>>> cebra_model.fit(neural_data)
174-
CEBRA(max_iterations=10)
174+
CEBRA(batch_size=512, max_iterations=10)
175175
>>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
176176
"""
177177
infonce = np.array(model.state_dict_["log"]["total"])

0 commit comments

Comments
 (0)