@@ -1324,3 +1324,94 @@ Below is the documentation on the available arguments.
13241324 --train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
13251325 --valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
13261326 --share-model
1327+
1328+ Model training using the Torch API
1329+ ----------------------------------
1330+
1331+ The scikit-learn API provides parametrization for many common use cases.
1332+ The Torch API, however, allows for more flexibility and customization, e.g. of
1333+ sampling, criterions, and data loaders.
1334+
1335+ In this minimal example we show how to initialize a CEBRA model using the Torch API.
1336+ Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
1337+ gets initialized, which also allows the ``prior`` to be directly parametrized.
1338+
1339+ 👉 For an example notebook using the Torch API, check out the :doc:`demo_notebooks/Demo_Allen`.
1340+
1341+
1342+ .. testcode::
1343+
# Minimal end-to-end example of training a CEBRA model with the Torch API:
# load the data, build the dataset, model, criterion, optimizer, solver and
# data loader, fit the model, then compute and plot the embedding.
import numpy as np
import cebra.datasets
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the neural recording and the discrete auxiliary label from disk.
neural_data = cebra.load_data(file="neural_data.npz", key="neural")

discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5",
    key="auxiliary_variables",
    columns=["discrete"],
)

# 1. Define a CEBRA-ready dataset
# The neural data is cast to float32 and the first label column to int64.
input_data = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to(device)

# 2. Define a CEBRA model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=input_data.input_dimension,
    num_units=32,
    num_output=2,
).to(device)

# NOTE(review): presumably aligns the dataset's sampling offset with the
# model's receptive field -- confirm against the cebra.data documentation.
input_data.configure_for(neural_model)

# 3. Define the Loss Function Criterion and Optimizer
crit = cebra.models.criterions.LearnableCosineInfoNCE(
    temperature=1,
).to(device)

# The criterion's own parameters are optimized jointly with the model's.
opt = torch.optim.Adam(
    list(neural_model.parameters()) + list(crit.parameters()),
    lr=0.001,
    weight_decay=0,
)

# 4. Initialize the CEBRA model
solver = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=crit,
    optimizer=opt,
    tqdm_on=True,
).to(device)

# 5. Define Data Loader
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=input_data, num_steps=10, batch_size=200, prior="uniform"
)

# 6. Fit Model
solver.fit(loader=loader)

# 7. Transform Embedding
# The length of the model's offset is used as the sliding-window size, so the
# windowed array has ``offset_len - 1`` fewer entries along axis 0 than
# ``neural_data``.
offset_len = len(neural_model.get_offset())

train_batches = np.lib.stride_tricks.sliding_window_view(
    neural_data, offset_len, axis=0
)

x_train_emb = solver.transform(
    torch.from_numpy(train_batches).type(torch.FloatTensor).to(device)
).to(device)

# 8. Plot Embedding
# The labels are trimmed by ``offset_len - 1`` so that they line up with the
# windowed embedding. The embedding is moved to host memory before plotting,
# which is a no-op when it already lives on the CPU.
# NOTE(review): assumes ``solver.transform`` returns a ``torch.Tensor`` (the
# ``.to(device)`` call above relies on the same) -- confirm.
cebra.plot_embedding(
    x_train_emb.cpu(),
    discrete_label[offset_len - 1 :, 0],
    markersize=10,
)
0 commit comments