Skip to content

Commit a1218aa

Browse files
committed
Fix pytorch API usage example
1 parent 217a8a7 commit a1218aa

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

docs/source/usage.rst

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,17 +1436,14 @@ gets initialized which also allows the `prior` to be directly parametrized.
14361436
solver.fit(loader=loader)
14371437

14381438
# 7. Transform Embedding
1439-
train_batches = np.lib.stride_tricks.sliding_window_view(
1440-
neural_data, neural_model.get_offset().__len__(), axis=0
1441-
)
1442-
14431439
x_train_emb = solver.transform(
1444-
torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device)
1445-
).to(device)
1440+
torch.from_numpy(neural_data).type(torch.FloatTensor).to(device),
1441+
pad_before_transform=True,
1442+
batch_size=512).to(device)
14461443

14471444
# 8. Plot Embedding
14481445
cebra.plot_embedding(
14491446
x_train_emb.cpu(),
1450-
discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
1447+
discrete_label[:,0],
14511448
markersize=10,
14521449
)

0 commit comments

Comments
 (0)