
Commit b83421d

Update tests and duplicate code based on review
1 parent cc8671c commit b83421d

8 files changed: +212 additions, -341 deletions

cebra/integrations/sklearn/cebra.py

Lines changed: 4 additions & 5 deletions
@@ -830,6 +830,8 @@ def _configure_for_all(
 
     def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
                       session_id: int):
+        if isinstance(X, np.ndarray):
+            X = torch.from_numpy(X)
         return self.solver_._select_model(X, session_id=session_id)
 
     def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
@@ -1061,7 +1063,8 @@ def _partial_fit(
         self.model_ = model
 
         self.n_features_ = solver.n_features
-        self.num_sessions_ = solver.num_sessions
+        self.num_sessions_ = solver.num_sessions if hasattr(
+            solver, "num_sessions") else None
         self.solver_ = solver
         self.n_features_in_ = ([model[n].num_input for n in range(len(model))]
                                if is_multisession else model.num_input)
@@ -1247,10 +1250,6 @@ def transform(self,
         if isinstance(X, np.ndarray):
             X = torch.from_numpy(X)
 
-        if batch_size is not None and batch_size < 1:
-            raise ValueError(
-                f"Batch size should be at least 1, got {batch_size}")
-
         with torch.no_grad():
             output = self.solver_.transform(
                 inputs=X,
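
The batch-size check removed here was duplicated logic; validation is left to the solver-side transform path to which `batch_size` is forwarded. A minimal usage sketch at the sklearn level (not part of this commit; the synthetic data, architecture name, and hyperparameters are illustrative assumptions):

import numpy as np
import cebra

X = np.random.normal(size=(1000, 30)).astype("float32")  # synthetic data (assumption)
model = cebra.CEBRA(model_architecture="offset10-model",
                    batch_size=256,
                    max_iterations=10)
model.fit(X)
embedding = model.transform(X, batch_size=512)  # batch_size is forwarded to solver_.transform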

cebra/solver/base.py

Lines changed: 111 additions & 30 deletions
@@ -36,7 +36,6 @@
 from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import literate_dataclasses as dataclasses
-import numpy.typing as npt
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
@@ -104,6 +103,15 @@ def _add_batched_zero_padding(batched_data: torch.Tensor,
     Returns:
         The padded batch.
     """
+    if batch_start_idx > batch_end_idx:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) cannot be greater than batch_end_idx ({batch_end_idx})."
+        )
+    if batch_start_idx < 0 or batch_end_idx < 0:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) and batch_end_idx ({batch_end_idx}) must be positive integers."
+        )
+
     reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1)
 
     if batch_start_idx == 0:  # First batch
@@ -179,7 +187,7 @@ def _inference_transform(model: cebra.models.Model,
     return output
 
 
-def _transform(
+def _not_batched_transform(
     model: cebra.models.Model,
     inputs: torch.Tensor,
     pad_before_transform: bool,
@@ -253,9 +261,11 @@ def __getitem__(self, idx):
         if batch_idx == (len(index_dataloader) - 1):
             # last batch, incomplete
             index_batch = torch.cat((last_batch, index_batch), dim=0)
-            assert index_batch[-1] + 1 == len(inputs), (
-                f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
-            )
+
+            if index_batch[-1] + 1 != len(inputs):
+                raise ValueError(
+                    f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
+                )
 
         # Batch start and end so that `batch_size` size with the last batch including 2 batches
         batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
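
The hunk above replaces an assert with an explicit ValueError when the merged final batch does not end at the last input index. A standalone sketch of the same index-merging idea, using a hypothetical helper name (`merge_last_batch` is not part of the codebase): the final incomplete batch is merged with the previous one so every forward pass sees at least `batch_size` samples.

import torch

def merge_last_batch(n_samples: int, batch_size: int) -> list:
    # split indices into batches, then fold the incomplete last batch into the previous one
    indices = torch.arange(n_samples)
    batches = list(torch.split(indices, batch_size))
    if len(batches) > 1 and len(batches[-1]) < batch_size:
        batches[-2:] = [torch.cat(batches[-2:])]
    if batches[-1][-1] + 1 != n_samples:
        raise ValueError("Last batch index should match the length of inputs.")
    return batches

print([len(b) for b in merge_last_batch(10, 4)])  # [4, 6]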
@@ -494,9 +504,6 @@ def fit(
             if logdir is not None:
                 self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")
 
-        assert hasattr(self, "n_features")
-        assert hasattr(self, "num_sessions")
-
     def step(self, batch: cebra.data.Batch) -> dict:
         """Perform a single gradient update.
 
@@ -540,7 +547,10 @@ def validation(self,
         Returns:
             Loss averaged over iterations on data batch.
         """
-        assert (session_id is None) or (session_id == 0)
+        if session_id is not None and session_id != 0:
+            raise ValueError(
+                f"session_id should be set to None or 0, got {session_id}")
+
         iterator = self._get_loader(loader)
         total_loss = Meter()
         self.model.eval()
@@ -569,7 +579,6 @@ def decoding(self, train_loader, valid_loader):
         )
         return decode_metric
 
-    @abc.abstractmethod
     def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
         """Check that the inputs can be inferred using the selected model.
 
@@ -582,7 +591,13 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
                 the number of sessions -1 for multisession, and set to
                 ``None`` for single session.
         """
-        raise NotImplementedError
+        if isinstance(inputs, list):
+            raise ValueError(
+                "Inputs to transform() should be the data for a single session, but received a list."
+            )
+        elif not isinstance(inputs, torch.Tensor):
+            raise ValueError(
+                f"Inputs should be a torch.Tensor, not {type(inputs)}.")
 
     @abc.abstractmethod
     def _check_is_session_id_valid(self, session_id: Optional[int] = None):
@@ -593,7 +608,6 @@ def _check_is_session_id_valid(self, session_id: Optional[int] = None):
        """
        raise NotImplementedError
 
-    @abc.abstractmethod
    def _select_model(
        self, inputs: Union[torch.Tensor,
                            List[torch.Tensor]], session_id: Optional[int]
@@ -610,6 +624,25 @@ def _select_model(
         Returns:
             The model (first returns) and the offset of the model (second returns).
         """
+        model = self._get_model(session_id=session_id)
+        offset = model.get_offset()
+
+        self._check_is_inputs_valid(inputs, session_id=session_id)
+        return model, offset
+
+    @abc.abstractmethod
+    def _get_model(self,
+                   session_id: Optional[int] = None) -> cebra.models.Model:
+        """Get the model to use for inference.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Returns:
+            The model.
+        """
         raise NotImplementedError
 
     def _check_is_fitted(self):
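
With this hunk, `_select_model` becomes a concrete template method on the base `Solver`: it fetches the model from the new abstract `_get_model` hook, validates the inputs against it, and returns the model with its offset. The per-solver `_get_model` implementations live in other files of this commit; below is a standalone, heavily simplified sketch of the pattern (class names and the dict placeholder are illustrative, not CEBRA APIs):

import abc

class Base(abc.ABC):

    def _select_model(self, inputs, session_id=None):
        # shared logic: pick the model, then derive its offset
        model = self._get_model(session_id=session_id)
        offset = model["offset"]  # stand-in for model.get_offset()
        return model, offset

    @abc.abstractmethod
    def _get_model(self, session_id=None):
        raise NotImplementedError

class SingleSession(Base):

    def _get_model(self, session_id=None):
        return {"name": "net", "offset": 5}  # placeholder for a cebra.models.Model

model, offset = SingleSession()._select_model(inputs=None)
print(offset)  # 5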
@@ -627,7 +660,7 @@ def _check_is_fitted(self):
 
     @torch.no_grad()
     def transform(self,
-                  inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
+                  inputs: torch.Tensor,
                   pad_before_transform: Optional[bool] = True,
                   session_id: Optional[int] = None,
                   batch_size: Optional[int] = None) -> torch.Tensor:
@@ -652,26 +685,40 @@ def transform(self,
         Returns:
             The output embedding.
         """
-        if isinstance(inputs, list):
-            raise ValueError(
-                "Inputs to transform() should be the data for a single session, but received a list."
-            )
-        elif not isinstance(inputs, torch.Tensor):
-            raise ValueError(
-                f"Inputs should be a torch.Tensor, not {type(inputs)}.")
-
         self._check_is_fitted()
-
         model, offset = self._select_model(inputs, session_id)
 
         if len(offset) < 2 and pad_before_transform:
             pad_before_transform = False
 
         model.eval()
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)
+
+    @torch.no_grad()
+    def _transform(self, model: cebra.models.Model, inputs: torch.Tensor,
+                   pad_before_transform: bool,
+                   offset: cebra.data.datatypes.Offset,
+                   batch_size: Optional[int]) -> torch.Tensor:
+        """Compute the embedding on the inputs using the model provided.
+
+        Args:
+            model: Model to use for inference.
+            inputs: Data.
+            pad_before_transform: If True zero-pad the batched data.
+            offset: Offset of the model to consider when padding.
+            batch_size: If not None, batched inference will not be applied.
+
+        Returns:
+            The embedding.
+        """
         if batch_size is not None and inputs.shape[0] > int(
-                batch_size * 2) and not isinstance(
-                    self.model, cebra.models.ResampleModelMixin):
-            # NOTE: resampling models are not supported for batched inference.
+                batch_size * 2) and not (isinstance(
+                    self._get_model(0), cebra.models.ResampleModelMixin)):
+            # NOTE(celia): resampling models are not supported for batched inference.
             output = _batched_transform(
                 model=model,
                 inputs=inputs,
@@ -680,11 +727,11 @@ def transform(self,
                 pad_before_transform=pad_before_transform,
             )
         else:
-            output = _transform(model=model,
-                                inputs=inputs,
-                                offset=offset,
-                                pad_before_transform=pad_before_transform)
-
+            output = _not_batched_transform(
+                model=model,
+                inputs=inputs,
+                offset=offset,
+                pad_before_transform=pad_before_transform)
         return output
 
     @abc.abstractmethod
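
For reference, the dispatch rule in `_transform` above: the batched path is only taken when a `batch_size` is given, the input is longer than twice that batch size, and the model is not a resampling model. A tiny illustration of the size condition (numbers are arbitrary):

def uses_batched_path(n_samples: int, batch_size) -> bool:
    # mirrors `batch_size is not None and inputs.shape[0] > int(batch_size * 2)`
    return batch_size is not None and n_samples > int(batch_size * 2)

print(uses_batched_path(1000, 512))  # False -> _not_batched_transform
print(uses_batched_path(1000, 256))  # True  -> _batched_transform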
@@ -863,3 +910,37 @@ def step(self, batch: cebra.data.Batch) -> dict:
             time_neg=time_uniform.item(),
             time_total=time_loss.item(),
         )
+
+
+class AuxiliaryVariableSolver(Solver):
+
+    @torch.no_grad()
+    def transform(self,
+                  inputs: torch.Tensor,
+                  pad_before_transform: bool = True,
+                  session_id: Optional[int] = None,
+                  batch_size: Optional[int] = None,
+                  use_reference_model: bool = False) -> torch.Tensor:
+        """Compute the embedding.
+        This function by default use ``model`` that was trained to encode the positive
+        and negative samples. To use ``reference_model`` instead of ``model``
+        ``use_reference_model`` should be equal ``True``.
+        Args:
+            inputs: The input signal
+            use_reference_model: Flag for using ``reference_model``
+        Returns:
+            The output embedding.
+        """
+        self._check_is_fitted()
+        model, offset = self._select_model(
+            inputs, session_id, use_reference_model=use_reference_model)
+
+        if len(offset) < 2 and pad_before_transform:
+            pad_before_transform = False
+
+        model.eval()
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)
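
Usage sketch for the new `AuxiliaryVariableSolver.transform` (hypothetical: assumes `solver` is an already fitted solver of this type and `neural_data` is a `torch.Tensor` of shape (time, features)):

embedding = solver.transform(neural_data, batch_size=512)          # default: the trained ``model``
reference_embedding = solver.transform(neural_data,
                                       batch_size=512,
                                       use_reference_model=True)   # uses ``reference_model`` instead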
