Skip to content

Commit 5745449

Browse files
committed
Run isort, ruff, yapf
1 parent 3acbdf4 commit 5745449

File tree

4 files changed

+8
-162
lines changed

4 files changed

+8
-162
lines changed

cebra/data/single_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def __post_init__(self):
370370

371371
self._init_behavior_distribution()
372372
self._init_time_distribution()
373-
373+
374374
if self.conditional != "time_delta":
375375
raise NotImplementedError(
376376
"Hybrid training is currently only implemented using the ``time_delta`` "

cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
np.dtypes.Float64DType, np.dtypes.Int64DType
5252
]
5353

54+
5455
def check_version(estimator):
5556
# NOTE(stes): required as a check for the old way of specifying tags
5657
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
@@ -76,8 +77,6 @@ def _safe_torch_load(filename, weights_only, **kwargs):
7677
return checkpoint
7778

7879

79-
80-
8180
def _init_loader(
8281
is_cont: bool,
8382
is_disc: bool,

cebra/solver/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import abc
3434
import os
3535
import warnings
36-
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
36+
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
3737

3838
import literate_dataclasses as dataclasses
3939
import numpy.typing as npt

tests/test_solver.py

Lines changed: 5 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,14 @@
5959
cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"),
6060
("demo-continuous-multisession",
6161
cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"),
62-
("demo-discrete-multisession",
63-
cebra.data.DiscreteMultiSessionDataLoader, "offset1-model"),
64-
("demo-discrete-multisession",
65-
cebra.data.DiscreteMultiSessionDataLoader, "offset10-model"),
62+
("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader,
63+
"offset1-model"),
64+
("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader,
65+
"offset10-model"),
6666
]:
6767
multi_session_tests.append((*args, cebra.solver.MultiSessionSolver))
6868

6969

70-
7170
def _get_loader(data, loader_initfunc):
7271
kwargs = dict(num_steps=5, batch_size=32)
7372
loader = loader_initfunc(data, **kwargs)
@@ -168,7 +167,7 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
168167

169168
assert solver.num_sessions is None
170169
assert solver.n_features == X.shape[1]
171-
170+
172171
embedding = solver.transform(X)
173172
assert isinstance(embedding, torch.Tensor)
174173
assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
@@ -527,158 +526,6 @@ def test_multi_session_2(data_name, loader_initfunc, solver_initfunc):
527526

528527
solver.fit(loader)
529528

530-
assert solver.num_sessions == 3
531-
assert solver.n_features == [X[i].shape[1] for i in range(len(X))]
532-
533-
embedding = solver.transform(X[0], session_id=0)
534-
assert isinstance(embedding, torch.Tensor)
535-
assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION)
536-
embedding = solver.transform(X[1], session_id=1)
537-
assert isinstance(embedding, torch.Tensor)
538-
assert embedding.shape == (X[1].shape[0], OUTPUT_DIMENSION)
539-
embedding = solver.transform(X[0], session_id=0, pad_before_transform=False)
540-
assert isinstance(embedding, torch.Tensor)
541-
assert embedding.shape == (X[0].shape[0] -
542-
len(solver.model[0].get_offset()) + 1,
543-
OUTPUT_DIMENSION)
544-
545-
with pytest.raises(ValueError, match="torch.Tensor"):
546-
embedding = solver.transform(X[0].numpy(), session_id=0)
547-
548-
with pytest.raises(ValueError, match="shape"):
549-
embedding = solver.transform(X[1], session_id=0)
550-
with pytest.raises(ValueError, match="shape"):
551-
embedding = solver.transform(X[0], session_id=1)
552-
553-
with pytest.raises(RuntimeError, match="No.*session_id"):
554-
embedding = solver.transform(X[0])
555-
with pytest.raises(ValueError, match="single.*session"):
556-
embedding = solver.transform(X)
557-
with pytest.raises(RuntimeError, match="Invalid.*session_id"):
558-
embedding = solver.transform(X[0], session_id=5)
559-
with pytest.raises(RuntimeError, match="Invalid.*session_id"):
560-
embedding = solver.transform(X[0], session_id=-1)
561-
562-
for param in solver.parameters(session_id=0):
563-
assert isinstance(param, torch.Tensor)
564-
565-
fitted_solver = copy.deepcopy(solver)
566-
with tempfile.TemporaryDirectory() as temp_dir:
567-
solver.save(temp_dir)
568-
solver.load(temp_dir)
569-
_assert_equal(fitted_solver, solver)
570-
571-
572-
@pytest.mark.parametrize(
573-
"inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output",
574-
[
575-
# Test case 1: No padding
576-
(torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
577-
0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch
578-
(torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
579-
0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch
580-
(torch.tensor(
581-
[[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset(
582-
0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch
583-
584-
# Test case 2: First batch with padding
585-
(
586-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
587-
True,
588-
cebra.data.Offset(0, 1),
589-
0,
590-
2,
591-
torch.tensor([[1, 2, 3], [4, 5, 6]]),
592-
),
593-
(
594-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
595-
True,
596-
cebra.data.Offset(1, 1),
597-
0,
598-
3,
599-
torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]),
600-
),
601-
602-
# Test case 3: Last batch with padding
603-
(
604-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
605-
True,
606-
cebra.data.Offset(0, 1),
607-
1,
608-
3,
609-
torch.tensor([[4, 5, 6], [7, 8, 9]]),
610-
),
611-
(
612-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
613-
[13, 14, 15]]),
614-
True,
615-
cebra.data.Offset(1, 2),
616-
1,
617-
3,
618-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
619-
),
620-
621-
# Test case 4: Middle batch with padding
622-
(
623-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
624-
True,
625-
cebra.data.Offset(0, 1),
626-
1,
627-
3,
628-
torch.tensor([[4, 5, 6], [7, 8, 9]]),
629-
),
630-
(
631-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
632-
True,
633-
cebra.data.Offset(1, 1),
634-
1,
635-
3,
636-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
637-
),
638-
(
639-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
640-
[13, 14, 15]]),
641-
True,
642-
cebra.data.Offset(0, 1),
643-
2,
644-
4,
645-
torch.tensor([[7, 8, 9], [10, 11, 12]]),
646-
),
647-
(
648-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
649-
True,
650-
cebra.data.Offset(0, 1),
651-
0,
652-
3,
653-
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
654-
),
655-
656-
# Examples that throw an error:
657-
658-
# Padding without offset (should raise an error)
659-
(torch.tensor([[1, 2]]), True, None, 0, 2, ValueError),
660-
# Negative start_batch_idx or end_batch_idx (should raise an error)
661-
(torch.tensor([[1, 2]]), False, cebra.data.Offset(
662-
0, 1), -1, 2, ValueError),
663-
# out of bound indices because offset is too large
664-
(torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset(
665-
5, 5), 1, 2, ValueError),
666-
# Batch length is smaller than offset.
667-
(torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset(
668-
0, 1), 0, 1, ValueError), # first batch
669-
],
670-
)
671-
def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx,
672-
expected_output):
673-
if expected_output == ValueError:
674-
with pytest.raises(ValueError):
675-
cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
676-
end_batch_idx, add_padding)
677-
else:
678-
result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
679-
end_batch_idx, add_padding)
680-
assert torch.equal(result, expected_output)
681-
682529

683530
def create_model(model_name, input_dimension):
684531
return cebra.models.init(model_name,

0 commit comments

Comments
 (0)