
Commit 3e91459

Add tests to solver
1 parent d71ca8d commit 3e91459

File tree

4 files changed: +226 -8 lines changed


cebra/data/multi_session.py

Lines changed: 12 additions & 0 deletions
@@ -116,6 +116,18 @@ def configure_for(self, model: "cebra.models.Model"):
         for i, session in enumerate(self.iter_sessions()):
             session.configure_for(model[i])
 
+    def configure_for(self, model: "cebra.models.Model"):
+        """Configure the dataset offset for the provided model.
+
+        Call this function before indexing the dataset. This sets the
+        :py:attr:`offset` attribute of the dataset.
+
+        Args:
+            model: The model to configure the dataset for.
+        """
+        for i, session in enumerate(self.iter_sessions()):
+            session.configure_for(model[i])
+
 
 @dataclasses.dataclass
 class MultiSessionLoader(cebra_data.Loader):
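The new docstring says to call configure_for before indexing the dataset. A minimal usage sketch, assuming a demo multi-session dataset and one model per session (the dataset name, model name, and unit/output sizes below are illustrative choices, not taken from this commit):

# Hedged sketch: configure a multi-session dataset for a list of per-session
# models before indexing it. Dataset/model names and sizes are assumptions.
import torch

import cebra.datasets
import cebra.models

dataset = cebra.datasets.init("demo-continuous-multisession")
model = torch.nn.ModuleList([
    cebra.models.init("offset10-model",
                      num_neurons=session.input_dimension,
                      num_units=32,
                      num_output=8)
    for session in dataset.iter_sessions()
])
# Sets the offset on every session so that indexing yields windows matching
# each session's model receptive field.
dataset.configure_for(model)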

cebra/data/single_session.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.
 
         Call this function before indexing the dataset. This sets the
-        `offset` attribute of the dataset.
+        :py:attr:`offset` attribute of the dataset.
 
         Args:
             model: The model to configure the dataset for.

cebra/solver/base.py

Lines changed: 7 additions & 2 deletions
@@ -38,6 +38,7 @@
 import literate_dataclasses as dataclasses
 import numpy.typing as npt
 import numpy as np
+import numpy.typing as npt
 import torch
 import torch.nn.functional as F
 import tqdm
@@ -569,8 +570,12 @@ def _select_model(
         raise NotImplementedError
 
     @property
-    def is_fitted(self):
-        return hasattr(self, "n_features")
+    def _check_is_fitted(self):
+        #NOTE(celia): instead of hasattr(model, "n_features_"), double check this!
+        if not (hasattr(self, "history") and len(self.history) > 0):
+            raise ValueError(
+                f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
+                "appropriate arguments before using this estimator.")
 
     @torch.no_grad()
     def transform(self,
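For orientation, a minimal sketch of the guard pattern this hunk moves toward: fit() populates history, and inference entry points call the check and raise instead of returning a boolean. The skeleton below is an assumption for illustration, not the actual cebra.solver.base.Solver.

class SketchSolver:
    """Toy stand-in for an estimator with a fitted-state guard."""

    def __init__(self):
        self.history = []  # loss values appended during fit()

    def fit(self, num_steps=3):
        for _ in range(num_steps):
            self.history.append(0.0)  # placeholder for a real training step

    def _check_is_fitted(self):
        if not (hasattr(self, "history") and len(self.history) > 0):
            raise ValueError(
                f"This {type(self).__name__} instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator.")

    def transform(self, inputs):
        self._check_is_fitted()  # fail fast on an unfitted solver
        return inputs

With this shape, callers such as transform() raise a descriptive error immediately rather than silently checking a boolean is_fitted flag.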

tests/test_solver.py

Lines changed: 206 additions & 5 deletions
@@ -59,11 +59,13 @@
      cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"),
     ("demo-continuous-multisession",
      cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"),
+    ("demo-discrete-multisession",
+     cebra.data.DiscreteMultiSessionDataLoader, "offset1-model"),
+    ("demo-discrete-multisession",
+     cebra.data.DiscreteMultiSessionDataLoader, "offset10-model"),
 ]:
     multi_session_tests.append((*args, cebra.solver.MultiSessionSolver))
 
-    # multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver))
-
 
 def _get_loader(data, loader_initfunc):
     kwargs = dict(num_steps=5, batch_size=32)
@@ -165,6 +167,28 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
 
     assert solver.num_sessions == None
     assert solver.n_features == X.shape[1]
+
+    embedding = solver.transform(X)
+    assert isinstance(embedding, torch.Tensor)
+    assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+    embedding = solver.transform(torch.Tensor(X))
+    assert isinstance(embedding, torch.Tensor)
+    assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+    embedding = solver.transform(X, session_id=0)
+    assert isinstance(embedding, torch.Tensor)
+    assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+    embedding = solver.transform(X, pad_before_transform=False)
+    assert isinstance(embedding, torch.Tensor)
+    assert embedding.shape == (X.shape[0] - len(offset) + 1, OUTPUT_DIMENSION)
+
+    with pytest.raises(ValueError, match="torch.Tensor"):
+        solver.transform(X.numpy())
+    with pytest.raises(RuntimeError, match="Invalid.*session_id"):
+        embedding = solver.transform(X, session_id=2)
+
+    for param in solver.parameters():
+        assert isinstance(param, torch.Tensor)
+
 
     embedding = solver.transform(X)
     assert isinstance(embedding, torch.Tensor)
@@ -320,6 +344,183 @@ def test_multi_session(data_name, loader_initfunc, model_architecture,
     assert solver.num_sessions == 3
     assert solver.n_features == [X[i].shape[1] for i in range(len(X))]
 
+    embedding = solver.transform(X[0], session_id=0)
+    assert isinstance(embedding, torch.Tensor)
+    assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION)
+    embedding = solver.transform(X[1], session_id=1)
+    assert isinstance(embedding, torch.Tensor)
+    assert embedding.shape == (X[1].shape[0], OUTPUT_DIMENSION)
+    embedding = solver.transform(X[0], session_id=0, pad_before_transform=False)
+    assert isinstance(embedding, torch.Tensor)
+    assert embedding.shape == (X[0].shape[0] -
+                               len(solver.model[0].get_offset()) + 1,
+                               OUTPUT_DIMENSION)
+
+    with pytest.raises(ValueError, match="torch.Tensor"):
+        embedding = solver.transform(X[0].numpy(), session_id=0)
+
+    with pytest.raises(ValueError, match="shape"):
+        embedding = solver.transform(X[1], session_id=0)
+    with pytest.raises(ValueError, match="shape"):
+        embedding = solver.transform(X[0], session_id=1)
+
+    with pytest.raises(RuntimeError, match="No.*session_id"):
+        embedding = solver.transform(X[0])
+    with pytest.raises(RuntimeError, match="single.*session"):
+        embedding = solver.transform(X)
+    with pytest.raises(RuntimeError, match="Invalid.*session_id"):
+        embedding = solver.transform(X[0], session_id=5)
+    with pytest.raises(RuntimeError, match="Invalid.*session_id"):
+        embedding = solver.transform(X[0], session_id=-1)
+
+    for param in solver.parameters(session_id=0):
+        assert isinstance(param, torch.Tensor)
+
+    with pytest.raises(RuntimeError, match="No.*session_id"):
+        for param in solver.parameters():
+            assert isinstance(param, torch.Tensor)
+
+
+@pytest.mark.parametrize(
+    "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output",
+    [
+        # Test case 1: No padding
+        (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
+            0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])),  # first batch
+        (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
+            0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])),  # last batch
+        (torch.tensor(
+            [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset(
+                0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])),  # middle batch
+
+        # Test case 2: First batch with padding
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+            True,
+            cebra.data.Offset(0, 1),
+            0,
+            2,
+            torch.tensor([[1, 2, 3], [4, 5, 6]]),
+        ),
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+            True,
+            cebra.data.Offset(1, 1),
+            0,
+            3,
+            torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+        ),
+
+        # Test case 3: Last batch with padding
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+            True,
+            cebra.data.Offset(0, 1),
+            1,
+            3,
+            torch.tensor([[4, 5, 6], [7, 8, 9]]),
+        ),
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
+                          [13, 14, 15]]),
+            True,
+            cebra.data.Offset(1, 2),
+            1,
+            3,
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+        ),
+
+        # Test case 4: Middle batch with padding
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+            True,
+            cebra.data.Offset(0, 1),
+            1,
+            3,
+            torch.tensor([[4, 5, 6], [7, 8, 9]]),
+        ),
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+            True,
+            cebra.data.Offset(1, 1),
+            1,
+            3,
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+        ),
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
+                          [13, 14, 15]]),
+            True,
+            cebra.data.Offset(0, 1),
+            2,
+            4,
+            torch.tensor([[7, 8, 9], [10, 11, 12]]),
+        ),
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+            True,
+            cebra.data.Offset(0, 1),
+            0,
+            3,
+            torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+        ),
+
+        # Examples that throw an error:
+
+        # Padding without offset (should raise an error)
+        (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError),
+        # Negative start_batch_idx or end_batch_idx (should raise an error)
+        (torch.tensor([[1, 2]]), False, cebra.data.Offset(
+            0, 1), -1, 2, ValueError),
+        # out of bound indices because offset is too large
+        (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset(
+            5, 5), 1, 2, ValueError),
+        # Batch length is smaller than offset.
+        (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset(
+            0, 1), 0, 1, ValueError),  # first batch
+    ],
+)
+def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx,
+                   expected_output):
+    if expected_output == ValueError:
+        with pytest.raises(ValueError):
+            cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
+                                         end_batch_idx, add_padding)
+    else:
+        result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
+                                              end_batch_idx, add_padding)
+        assert torch.equal(result, expected_output)
+
+
+@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc",
+                         multi_session_tests)
+def test_multi_session_2(data_name, loader_initfunc, solver_initfunc):
+    loader = _get_loader(data_name, loader_initfunc)
+    criterion = cebra.models.InfoNCE()
+    model = nn.ModuleList(
+        [_make_model(dataset) for dataset in loader.dataset.iter_sessions()])
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+    solver = solver_initfunc(model=model,
+                             criterion=criterion,
+                             optimizer=optimizer,
+                             tqdm_on=True)
+
+    batch = next(iter(loader))
+    for session_id, dataset in enumerate(loader.dataset.iter_sessions()):
+        assert batch[session_id].reference.shape == (32,
+                                                     dataset.input_dimension,
+                                                     10)
+        assert batch[session_id].index is not None
+
+    log = solver.step(batch)
+    assert isinstance(log, dict)
+
+    solver.fit(loader)
+
+    assert solver.num_sessions == 3
+    assert solver.n_features == [X[i].shape[1] for i in range(len(X))]
+
     embedding = solver.transform(X[0], session_id=0)
     assert isinstance(embedding, torch.Tensor)
     assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION)
@@ -504,8 +705,8 @@ def create_model(model_name, input_dimension):
 
 @pytest.mark.parametrize(
     "data_name, model_name ,session_id, loader_initfunc, solver_initfunc",
-    single_session_tests_select_model + single_session_hybrid_tests_select_model
-)
+    single_session_tests_select_model +
+    single_session_hybrid_tests_select_model)
 def test_select_model_single_session(data_name, model_name, session_id,
                                      loader_initfunc, solver_initfunc):
     dataset = cebra.datasets.init(data_name)
@@ -576,7 +777,7 @@ def test_select_model_multi_session(data_name, model_name, session_id,
     "offset40-model-4x-subsample",
     "offset1-model",
     "offset10-model",
-]
+] #NOTE(rodrigo): there is an issue with "offset4-model-2x-subsample" because it's not a convolutional model.
 batch_size_inference = [40_000, 99_990, 99_999]
 
 single_session_tests_transform = []
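To make the padding expectations in test_get_batch easier to follow: a batch covering output rows [start, end) draws on input rows [start - offset.left, end + offset.right - 1), replicating the first or last row whenever that window runs past the data. The sketch below is a reference reimplementation consistent with the non-error cases above, written as an assumption for illustration; it is not the actual cebra.solver.base._get_batch.

# Hedged reference sketch (not cebra.solver.base._get_batch): reproduces the
# expected tensors of the non-error parametrized cases in test_get_batch.
import torch

import cebra.data


def get_batch_sketch(inputs, offset, start, end, add_padding):
    if not add_padding:
        return inputs[start:end]
    if offset is None:
        raise ValueError("Padding requires a valid offset.")
    lo, hi = start - offset.left, end + offset.right - 1
    left_pad, right_pad = max(0, -lo), max(0, hi - len(inputs))
    batch = inputs[max(lo, 0):min(hi, len(inputs))]
    if left_pad:  # window starts before row 0: replicate the first row
        batch = torch.cat([inputs[:1].repeat(left_pad, 1), batch])
    if right_pad:  # window ends past the last row: replicate the last row
        batch = torch.cat([batch, inputs[-1:].repeat(right_pad, 1)])
    return batch


# For example, Offset(1, 1) over the full 3-row input pads one replicated row
# in front, matching the expected tensor in the parametrization above.
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
assert torch.equal(
    get_batch_sketch(x, cebra.data.Offset(1, 1), 0, 3, True),
    torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]))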
