@@ -1203,14 +1203,17 @@ Putting all previous snippet examples together, we obtain the following pipeline
 import cebra
 from numpy.random import uniform, randint
 from sklearn.model_selection import train_test_split
+import os
+import tempfile
+from pathlib import Path

 # 1. Define a CEBRA model
 cebra_model = cebra.CEBRA(
     model_architecture = "offset10-model",
     batch_size = 512,
     learning_rate = 1e-4,
-    max_iterations = 10, # TODO(user): to change to at least 10'000
-    max_adapt_iterations = 10, # TODO(user): to change to ~100-500
+    max_iterations = 10, # TODO(user): change to ~5000-10000
+    # max_adapt_iterations = 10, # TODO(user): uncomment and change to ~100-500 if adapting
     time_offsets = 10,
     output_dimension = 8,
     verbose = False
@@ -1244,7 +1247,7 @@ Putting all previous snippet examples together, we obtain the following pipeline
 # time contrastive learning
 cebra_model.fit(train_data)
 # discrete behavior contrastive learning
-cebra_model.fit(train_data, train_discrete_label, )
+cebra_model.fit(train_data, train_discrete_label)
 # continuous behavior contrastive learning
 cebra_model.fit(train_data, train_continuous_label)
 # mixed behavior contrastive learning
@@ -1258,10 +1261,10 @@ Putting all previous snippet examples together, we obtain the following pipeline
 cebra_model = cebra.CEBRA.load(tmp_file)
 train_embedding = cebra_model.transform(train_data)
 valid_embedding = cebra_model.transform(valid_data)
-assert train_embedding.shape == (70, 8)
-assert valid_embedding.shape == (30, 8)
+assert train_embedding.shape == (70, 8) # TODO(user): adjust to your train/valid split ratio & output dimension
+assert valid_embedding.shape == (30, 8) # TODO(user): adjust to your train/valid split ratio & output dimension

-# 7. Evaluate the model performances
+# 7. Evaluate the model performance (you can also evaluate on train_data)
 goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model,
                                                      valid_data,
                                                      valid_discrete_label,
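Note on the commented-out max_adapt_iterations above: it only takes effect when a pretrained model is adapted to new data via CEBRA's fit(..., adapt=True) option. A minimal sketch of that step, assuming the pipeline above has already run and using a hypothetical new_session_data array with the same number of features as train_data:

import numpy as np

# Hypothetical new recording with the same feature dimensionality as train_data.
new_session_data = np.random.normal(0, 1, (50, train_data.shape[1]))

# With adapt=True the pretrained model is adapted to the new data;
# the number of adaptation steps is controlled by max_adapt_iterations.
cebra_model.fit(new_session_data, adapt=True)
adapted_embedding = cebra_model.transform(new_session_data)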