Skip to content

Commit fa3cd3e

Browse files
committed
Merge remote-tracking branch 'upstream/main' into batched-inference-and-padding
2 parents 5745449 + 92c8b1f commit fa3cd3e

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

cebra/solver/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737

3838
import literate_dataclasses as dataclasses
3939
import numpy.typing as npt
40-
import numpy as np
41-
import numpy.typing as npt
4240
import torch
4341
import torch.nn.functional as F
4442
from torch.utils.data import DataLoader
@@ -616,9 +614,9 @@ def transform(self,
616614
elif not isinstance(inputs, torch.Tensor):
617615
raise ValueError(
618616
f"Inputs should be a torch.Tensor, not {type(inputs)}.")
619-
617+
620618
self._check_is_fitted()
621-
619+
622620
model, offset = self._select_model(inputs, session_id)
623621

624622
if len(offset) < 2 and pad_before_transform:

tests/test_datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def test_demo():
6868

6969
@pytest.mark.requires_dataset
7070
def test_hippocampus():
71-
7271
pytest.skip("Outdated")
7372
dataset = cebra.datasets.init("rat-hippocampus-single")
7473
loader = cebra.data.ContinuousDataLoader(

tests/test_solver.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
195195
solver.load(temp_dir)
196196
_assert_equal(fitted_solver, solver)
197197

198-
199198
embedding = solver.transform(X)
200199
assert isinstance(embedding, torch.Tensor)
201200
assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
@@ -224,9 +223,11 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
224223
_assert_equal(fitted_solver, solver)
225224

226225

227-
@pytest.mark.parametrize("data_name, loader_initfunc, model_architecture, solver_initfunc",
228-
single_session_tests)
229-
def test_single_session_auxvar(data_name, loader_initfunc, model_architecture, solver_initfunc):
226+
@pytest.mark.parametrize(
227+
"data_name, loader_initfunc, model_architecture, solver_initfunc",
228+
single_session_tests)
229+
def test_single_session_auxvar(data_name, loader_initfunc, model_architecture,
230+
solver_initfunc):
230231

231232
pytest.skip("Not yet supported")
232233

0 commit comments

Comments
 (0)