Skip to content

Commit a7bacc8

Browse files
authored
Update usage.rst
1 parent 31473a7 commit a7bacc8

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

docs/source/usage.rst

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,14 +1203,17 @@ Putting all previous snippet examples together, we obtain the following pipeline
12031203
import cebra
12041204
from numpy.random import uniform, randint
12051205
from sklearn.model_selection import train_test_split
1206+
import os
1207+
import tempfile
1208+
from pathlib import Path
12061209

12071210
# 1. Define a CEBRA model
12081211
cebra_model = cebra.CEBRA(
12091212
model_architecture = "offset10-model",
12101213
batch_size = 512,
12111214
learning_rate = 1e-4,
1212-
max_iterations = 10, # TODO(user): to change to at least 10'000
1213-
max_adapt_iterations = 10, # TODO(user): to change to ~100-500
1215+
max_iterations = 10, # TODO(user): to change to ~5000-10000
1216+
#max_adapt_iterations = 10, # TODO(user): use and to change to ~100-500 if adapting
12141217
time_offsets = 10,
12151218
output_dimension = 8,
12161219
verbose = False
@@ -1244,7 +1247,7 @@ Putting all previous snippet examples together, we obtain the following pipeline
12441247
# time contrastive learning
12451248
cebra_model.fit(train_data)
12461249
# discrete behavior contrastive learning
1247-
cebra_model.fit(train_data, train_discrete_label,)
1250+
cebra_model.fit(train_data, train_discrete_label)
12481251
# continuous behavior contrastive learning
12491252
cebra_model.fit(train_data, train_continuous_label)
12501253
# mixed behavior contrastive learning
@@ -1258,10 +1261,10 @@ Putting all previous snippet examples together, we obtain the following pipeline
12581261
cebra_model = cebra.CEBRA.load(tmp_file)
12591262
train_embedding = cebra_model.transform(train_data)
12601263
valid_embedding = cebra_model.transform(valid_data)
1261-
assert train_embedding.shape == (70, 8)
1262-
assert valid_embedding.shape == (30, 8)
1264+
assert train_embedding.shape == (70, 8) # TODO(user): change to split ration & output dim
1265+
assert valid_embedding.shape == (30, 8) # TODO(user): change to split ration & output dim
12631266

1264-
# 7. Evaluate the model performances
1267+
# 7. Evaluate the model performance (you can also check the train_data)
12651268
goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model,
12661269
valid_data,
12671270
valid_discrete_label,

0 commit comments

Comments
 (0)