Commit 66fc6aa

Update tests and duplicate code based on review
1 parent 2fcfb7f commit 66fc6aa

9 files changed: +213 additions, -342 deletions


cebra/datasets/demo.py

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@
 import cebra.io
 from cebra.datasets import register

-_DEFAULT_NUM_TIMEPOINTS = 100000
+_DEFAULT_NUM_TIMEPOINTS = 1_000


 class DemoDataset(cebra.data.SingleSessionDataset):
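The only change here shrinks the default size of the synthetic demo dataset, which keeps test fixtures fast. A minimal sketch of the effect, assuming the demo dataset is registered under the key "demo-continuous" and exposes a `neural` array (both assumptions about the CEBRA registry, not shown in this diff):

import cebra.datasets

# Assumed registry key; the demo datasets are registered in cebra/datasets/demo.py.
dataset = cebra.datasets.init("demo-continuous")
# With the new default, the synthetic recording has 1_000 timepoints
# instead of 100_000, so anything built on top of it processes far less data.
print(dataset.neural.shape)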

cebra/integrations/sklearn/cebra.py

Lines changed: 4 additions & 5 deletions

@@ -826,6 +826,8 @@ def _configure_for_all(

     def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
                       session_id: int):
+        if isinstance(X, np.ndarray):
+            X = torch.from_numpy(X)
         return self.solver_._select_model(X, session_id=session_id)

     def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):

@@ -1055,7 +1057,8 @@ def _partial_fit(
         self.model_ = model

         self.n_features_ = solver.n_features
-        self.num_sessions_ = solver.num_sessions
+        self.num_sessions_ = solver.num_sessions if hasattr(
+            solver, "num_sessions") else None
         self.solver_ = solver
         self.n_features_in_ = ([model[n].num_input for n in range(len(model))]
                                if is_multisession else model.num_input)

@@ -1241,10 +1244,6 @@ def transform(self,
         if isinstance(X, np.ndarray):
             X = torch.from_numpy(X)

-        if batch_size is not None and batch_size < 1:
-            raise ValueError(
-                f"Batch size should be at least 1, got {batch_size}")
-
         with torch.no_grad():
             output = self.solver_.transform(
                 inputs=X,
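The addition to `_select_model` mirrors what `transform` already does: NumPy input is converted to a tensor before being handed to the solver. The pattern itself, stripped of CEBRA specifics (the helper name below is illustrative, not part of the library):

import numpy as np
import torch

def ensure_tensor(X):
    # Accept either a NumPy array or a torch.Tensor and always return a tensor,
    # matching the conversion added to _select_model above.
    if isinstance(X, np.ndarray):
        X = torch.from_numpy(X)
    return X

x = ensure_tensor(np.random.randn(100, 5).astype("float32"))
assert isinstance(x, torch.Tensor)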

cebra/solver/base.py

Lines changed: 111 additions & 30 deletions

@@ -36,7 +36,6 @@
 from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

 import literate_dataclasses as dataclasses
-import numpy.typing as npt
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader

@@ -104,6 +103,15 @@ def _add_batched_zero_padding(batched_data: torch.Tensor,
     Returns:
         The padded batch.
     """
+    if batch_start_idx > batch_end_idx:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) cannot be greater than batch_end_idx ({batch_end_idx})."
+        )
+    if batch_start_idx < 0 or batch_end_idx < 0:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) and batch_end_idx ({batch_end_idx}) must be positive integers."
+        )
+
     reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1)

     if batch_start_idx == 0:  # First batch
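The two new guards replace silent misbehaviour with explicit errors when batch boundaries are inconsistent. As a standalone sketch (a simplified stand-in for the checks above, not the CEBRA helper itself):

def check_batch_bounds(batch_start_idx: int, batch_end_idx: int) -> None:
    # Reject reversed or negative index ranges before any padding is applied.
    if batch_start_idx > batch_end_idx:
        raise ValueError(
            f"batch_start_idx ({batch_start_idx}) cannot be greater than "
            f"batch_end_idx ({batch_end_idx}).")
    if batch_start_idx < 0 or batch_end_idx < 0:
        raise ValueError("batch indices must be non-negative.")

check_batch_bounds(0, 128)    # valid range, passes silently
# check_batch_bounds(10, 5)   # would raise: start greater than end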
@@ -179,7 +187,7 @@ def _inference_transform(model: cebra.models.Model,
     return output


-def _transform(
+def _not_batched_transform(
     model: cebra.models.Model,
     inputs: torch.Tensor,
     pad_before_transform: bool,

@@ -253,9 +261,11 @@ def __getitem__(self, idx):
         if batch_idx == (len(index_dataloader) - 1):
             # last batch, incomplete
             index_batch = torch.cat((last_batch, index_batch), dim=0)
-            assert index_batch[-1] + 1 == len(inputs), (
-                f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
-            )
+
+            if index_batch[-1] + 1 != len(inputs):
+                raise ValueError(
+                    f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
+                )

         # Batch start and end so that `batch_size` size with the last batch including 2 batches
         batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
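The hunk above also documents the batching convention this check protects: the final, incomplete batch is concatenated onto the previous full batch, so the merged chunk must end exactly at the input length. A small sketch of that bookkeeping with plain index tensors (numbers are illustrative):

import torch

n_samples, batch_size = 1000, 300
indices = torch.arange(n_samples)

# Full batches cover [0:300], [300:600], [600:900]; the 100 leftover indices
# are merged with the previous batch, so the last processed chunk is [600:1000].
last_batch = indices[600:900]
index_batch = torch.cat((last_batch, indices[900:]), dim=0)

batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
assert batch_end_idx == n_samples  # the merged chunk ends at the input length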
@@ -469,9 +479,6 @@ def fit(
             if logdir is not None:
                 self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")

-        assert hasattr(self, "n_features")
-        assert hasattr(self, "num_sessions")
-
     def step(self, batch: cebra.data.Batch) -> dict:
         """Perform a single gradient update.

@@ -515,7 +522,10 @@ def validation(self,
         Returns:
             Loss averaged over iterations on data batch.
         """
-        assert (session_id is None) or (session_id == 0)
+        if session_id is not None and session_id != 0:
+            raise ValueError(
+                f"session_id should be set to None or 0, got {session_id}")
+
         iterator = self._get_loader(loader)
         total_loss = Meter()
         self.model.eval()

@@ -544,7 +554,6 @@ def decoding(self, train_loader, valid_loader):
         )
         return decode_metric

-    @abc.abstractmethod
     def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
         """Check that the inputs can be inferred using the selected model.

@@ -557,7 +566,13 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
                 the number of sessions -1 for multisession, and set to
                 ``None`` for single session.
         """
-        raise NotImplementedError
+        if isinstance(inputs, list):
+            raise ValueError(
+                "Inputs to transform() should be the data for a single session, but received a list."
+            )
+        elif not isinstance(inputs, torch.Tensor):
+            raise ValueError(
+                f"Inputs should be a torch.Tensor, not {type(inputs)}.")

     @abc.abstractmethod
     def _check_is_session_id_valid(self, session_id: Optional[int] = None):

@@ -568,7 +583,6 @@ def _check_is_session_id_valid(self, session_id: Optional[int] = None):
         """
         raise NotImplementedError

-    @abc.abstractmethod
     def _select_model(
         self, inputs: Union[torch.Tensor,
                             List[torch.Tensor]], session_id: Optional[int]

@@ -585,6 +599,25 @@ def _select_model(
         Returns:
             The model (first returns) and the offset of the model (second returns).
         """
+        model = self._get_model(session_id=session_id)
+        offset = model.get_offset()
+
+        self._check_is_inputs_valid(inputs, session_id=session_id)
+        return model, offset
+
+    @abc.abstractmethod
+    def _get_model(self,
+                   session_id: Optional[int] = None) -> cebra.models.Model:
+        """Get the model to use for inference.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Returns:
+            The model.
+        """
         raise NotImplementedError

     def _check_is_fitted(self):
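With these changes, `_select_model` becomes a concrete template method: it asks `_get_model` (now the only abstract hook) for the session's model, reads its offset, and validates the inputs. A rough sketch of how a single-session subclass could satisfy the new contract (illustrative only; the actual CEBRA subclasses live elsewhere in cebra/solver/):

from cebra.solver.base import Solver  # base class shown in this diff

class SingleSessionSolverSketch(Solver):
    """Hypothetical subclass: one model, so session_id is only validated."""

    def _get_model(self, session_id=None):
        self._check_is_session_id_valid(session_id=session_id)
        return self.model

    def _check_is_session_id_valid(self, session_id=None):
        if session_id is not None and session_id != 0:
            raise ValueError(
                f"session_id should be None or 0 for a single session, got {session_id}.")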
@@ -602,7 +635,7 @@ def _check_is_fitted(self):

     @torch.no_grad()
     def transform(self,
-                  inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
+                  inputs: torch.Tensor,
                   pad_before_transform: Optional[bool] = True,
                   session_id: Optional[int] = None,
                   batch_size: Optional[int] = None) -> torch.Tensor:

@@ -627,26 +660,40 @@ def transform(self,
         Returns:
             The output embedding.
         """
-        if isinstance(inputs, list):
-            raise ValueError(
-                "Inputs to transform() should be the data for a single session, but received a list."
-            )
-        elif not isinstance(inputs, torch.Tensor):
-            raise ValueError(
-                f"Inputs should be a torch.Tensor, not {type(inputs)}.")
-
         self._check_is_fitted()
-
         model, offset = self._select_model(inputs, session_id)

         if len(offset) < 2 and pad_before_transform:
             pad_before_transform = False

         model.eval()
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)
+
+    @torch.no_grad()
+    def _transform(self, model: cebra.models.Model, inputs: torch.Tensor,
+                   pad_before_transform: bool,
+                   offset: cebra.data.datatypes.Offset,
+                   batch_size: Optional[int]) -> torch.Tensor:
+        """Compute the embedding on the inputs using the model provided.
+
+        Args:
+            model: Model to use for inference.
+            inputs: Data.
+            pad_before_transform: If True zero-pad the batched data.
+            offset: Offset of the model to consider when padding.
+            batch_size: If not None, batched inference will not be applied.
+
+        Returns:
+            The embedding.
+        """
         if batch_size is not None and inputs.shape[0] > int(
-                batch_size * 2) and not isinstance(
-                    self.model, cebra.models.ResampleModelMixin):
-            # NOTE: resampling models are not supported for batched inference.
+                batch_size * 2) and not (isinstance(
+                    self._get_model(0), cebra.models.ResampleModelMixin)):
+            # NOTE(celia): resampling models are not supported for batched inference.
             output = _batched_transform(
                 model=model,
                 inputs=inputs,

@@ -655,11 +702,11 @@ def transform(self,
                 pad_before_transform=pad_before_transform,
             )
         else:
-            output = _transform(model=model,
-                                inputs=inputs,
-                                offset=offset,
-                                pad_before_transform=pad_before_transform)
-
+            output = _not_batched_transform(
+                model=model,
+                inputs=inputs,
+                offset=offset,
+                pad_before_transform=pad_before_transform)
         return output

     @abc.abstractmethod
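After the refactor, `transform` only selects the model and delegates to `_transform`, which picks between batched and single-pass inference: batching is used only when a `batch_size` is given, the input is longer than twice that batch size, and the model is not a resampling model. A condensed, standalone restatement of that dispatch rule (stand-in function, not CEBRA code):

from typing import Optional

def choose_inference_path(n_samples: int, batch_size: Optional[int],
                          is_resample_model: bool) -> str:
    # Mirrors the condition in _transform above.
    if (batch_size is not None and n_samples > int(batch_size * 2)
            and not is_resample_model):
        return "batched"
    return "single-pass"

print(choose_inference_path(10_000, 512, False))   # batched
print(choose_inference_path(800, 512, False))      # single-pass: input too short
print(choose_inference_path(10_000, None, False))  # single-pass: no batch_size
print(choose_inference_path(10_000, 512, True))    # single-pass: resampling model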
@@ -838,3 +885,37 @@ def step(self, batch: cebra.data.Batch) -> dict:
             time_neg=time_uniform.item(),
             time_total=time_loss.item(),
         )
+
+
+class AuxiliaryVariableSolver(Solver):
+
+    @torch.no_grad()
+    def transform(self,
+                  inputs: torch.Tensor,
+                  pad_before_transform: bool = True,
+                  session_id: Optional[int] = None,
+                  batch_size: Optional[int] = None,
+                  use_reference_model: bool = False) -> torch.Tensor:
+        """Compute the embedding.
+        This function by default use ``model`` that was trained to encode the positive
+        and negative samples. To use ``reference_model`` instead of ``model``
+        ``use_reference_model`` should be equal ``True``.
+        Args:
+            inputs: The input signal
+            use_reference_model: Flag for using ``reference_model``
+        Returns:
+            The output embedding.
+        """
+        self._check_is_fitted()
+        model, offset = self._select_model(
+            inputs, session_id, use_reference_model=use_reference_model)
+
+        if len(offset) < 2 and pad_before_transform:
+            pad_before_transform = False
+
+        model.eval()
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)
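The new `AuxiliaryVariableSolver.transform` exposes a switch between the default encoder and the reference encoder. A hypothetical usage sketch, assuming `solver` is an already fitted `AuxiliaryVariableSolver` and `neural` is a `torch.Tensor` of shape `(time, features)`:

# Default: embed with the model trained on positive and negative samples.
embedding = solver.transform(neural, session_id=0, batch_size=512)

# Same data embedded with the reference model instead.
reference_embedding = solver.transform(neural,
                                       session_id=0,
                                       batch_size=512,
                                       use_reference_model=True)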
