Skip to content

Commit 718f7ca

Browse files
authored
Merge branch 'main' into unified-cebra
2 parents b4caf3a + 7ae5e1e commit 718f7ca

File tree

4 files changed

+3
-4
lines changed

4 files changed

+3
-4
lines changed

cebra/datasets/demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
NUMS_NEURAL = [3, 4, 5]
3737

3838

39+
3940
class DemoDataset(cebra.data.SingleSessionDataset):
4041

4142
def __init__(self, num_timepoints=_DEFAULT_NUM_TIMEPOINTS, num_neural=4):

cebra/solver/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,6 @@ def num_parameters(self) -> int:
396396
@abc.abstractmethod
397397
def parameters(self, session_id: Optional[int] = None):
398398
"""Iterate over all parameters of the model.
399-
400399
Args:
401400
session_id: The session ID, an :py:class:`int` between 0 and
402401
the number of sessions -1 for multisession, and set to

cebra/solver/single_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import copy
2525
from typing import Optional, Tuple
2626

27+
2728
import literate_dataclasses as dataclasses
2829
import torch
2930

tests/test_solver.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ def _get_loader(data_name, loader_initfunc):
4141
loader = loader_initfunc(data, **kwargs)
4242
return loader, data
4343

44-
4544
OUTPUT_DIMENSION = 3
4645

47-
4846
def _make_model(dataset, model_architecture="offset10-model"):
4947
# TODO flexible input dimension
5048
# return nn.Sequential(
@@ -400,4 +398,4 @@ def test_unified_session(data_name, model_architecture, loader_initfunc,
400398
assert emb.shape == (loader.dataset.num_timepoints, 3)
401399

402400
emb = solver.transform(data, labels, session_id=i, batch_size=300)
403-
assert emb.shape == (loader.dataset.num_timepoints, 3)
401+
assert emb.shape == (loader.dataset.num_timepoints, 3)

0 commit comments

Comments
 (0)