Skip to content

Commit 1438bb0

Browse files
timonmerkstesMMathisLab
authored
Add torch API usage example (#99)
* add torch API usage example * Update docs/source/usage.rst * Update docs/source/usage.rst * Update usage.rst - fix typo * Update usage.rst - FIx other variables * Minor edit * Update usage.rst - minor typesetting * Update usage.rst --------- Co-authored-by: Steffen Schneider <[email protected]> Co-authored-by: Steffen Schneider <[email protected]> Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent 8b9bb1a commit 1438bb0

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

docs/source/usage.rst

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,3 +1324,94 @@ Below is the documentation on the available arguments.
13241324
--train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
13251325
--valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
13261326
--share-model
1327+
1328+
Model training using the Torch API
1329+
----------------------------------
1330+
1331+
The scikit-learn API provides parametrization to many common use cases.
1332+
The Torch API however allows for more flexibility and customization, for e.g.
1333+
sampling, criterions, and data loaders.
1334+
1335+
In this minimal example we show how to initialize a CEBRA model using the Torch API.
1336+
Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
1337+
gets initialized which also allows the `prior` to be directly parametrized.
1338+
1339+
👉 For an example notebook using the Torch API check out the :doc:`demo_notebooks/Demo_Allen`.
1340+
1341+
1342+
.. testcode::
1343+
1344+
import numpy as np
1345+
import cebra.datasets
1346+
import torch
1347+
1348+
if torch.cuda.is_available():
1349+
device = "cuda"
1350+
else:
1351+
device = "cpu"
1352+
1353+
neural_data = cebra.load_data(file="neural_data.npz", key="neural")
1354+
1355+
discrete_label = cebra.load_data(
1356+
file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
1357+
)
1358+
1359+
# 1. Define a CEBRA-ready dataset
1360+
input_data = cebra.data.TensorDataset(
1361+
torch.from_numpy(neural_data).type(torch.FloatTensor),
1362+
discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
1363+
).to(device)
1364+
1365+
# 2. Define a CEBRA model
1366+
neural_model = cebra.models.init(
1367+
name="offset10-model",
1368+
num_neurons=input_data.input_dimension,
1369+
num_units=32,
1370+
num_output=2,
1371+
).to(device)
1372+
1373+
input_data.configure_for(neural_model)
1374+
1375+
# 3. Define the Loss Function Criterion and Optimizer
1376+
crit = cebra.models.criterions.LearnableCosineInfoNCE(
1377+
temperature=1,
1378+
).to(device)
1379+
1380+
opt = torch.optim.Adam(
1381+
list(neural_model.parameters()) + list(crit.parameters()),
1382+
lr=0.001,
1383+
weight_decay=0,
1384+
)
1385+
1386+
# 4. Initialize the CEBRA model
1387+
solver = cebra.solver.init(
1388+
name="single-session",
1389+
model=neural_model,
1390+
criterion=crit,
1391+
optimizer=opt,
1392+
tqdm_on=True,
1393+
).to(device)
1394+
1395+
# 5. Define Data Loader
1396+
loader = cebra.data.single_session.DiscreteDataLoader(
1397+
dataset=input_data, num_steps=10, batch_size=200, prior="uniform"
1398+
)
1399+
1400+
# 6. Fit Model
1401+
solver.fit(loader=loader)
1402+
1403+
# 7. Transform Embedding
1404+
train_batches = np.lib.stride_tricks.sliding_window_view(
1405+
neural_data, neural_model.get_offset().__len__(), axis=0
1406+
)
1407+
1408+
x_train_emb = solver.transform(
1409+
torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device)
1410+
).to(device)
1411+
1412+
# 8. Plot Embedding
1413+
cebra.plot_embedding(
1414+
x_train_emb,
1415+
discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
1416+
markersize=10,
1417+
)

0 commit comments

Comments
 (0)