 from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

 import literate_dataclasses as dataclasses
-import numpy.typing as npt
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
@@ -104,6 +103,15 @@ def _add_batched_zero_padding(batched_data: torch.Tensor,
     Returns:
         The padded batch.
     """
+    if batch_start_idx > batch_end_idx:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) cannot be greater than batch_end_idx ({batch_end_idx})."
+        )
+    if batch_start_idx < 0 or batch_end_idx < 0:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) and batch_end_idx ({batch_end_idx}) must be non-negative integers."
+        )
+
     reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1)

     if batch_start_idx == 0:  # First batch
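For context on what these guards protect: `_add_batched_zero_padding` pads only the leading edge of the first batch and only the trailing edge of the last batch, since interior batches get their context from neighboring samples. A minimal standalone sketch of that one-sided padding (pad widths and shapes here are illustrative; in the real code they are derived from the model offset):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 3)   # one batch: (time, features)
left, right = 2, 2      # illustrative pad widths

# First batch: pad only at the start; the right-hand context
# comes from the following batch.
first = F.pad(x.T, (left, 0)).T   # shape (6, 3)
# Last batch: pad only at the end.
last = F.pad(x.T, (0, right)).T   # shape (6, 3)
```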
@@ -179,7 +187,7 @@ def _inference_transform(model: cebra.models.Model,
     return output


-def _transform(
+def _not_batched_transform(
     model: cebra.models.Model,
     inputs: torch.Tensor,
     pad_before_transform: bool,
@@ -253,9 +261,11 @@ def __getitem__(self, idx):
         if batch_idx == (len(index_dataloader) - 1):
             # last batch, incomplete
             index_batch = torch.cat((last_batch, index_batch), dim=0)
-            assert index_batch[-1] + 1 == len(inputs), (
-                f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
-            )
+
+            if index_batch[-1] + 1 != len(inputs):
+                raise ValueError(
+                    f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
+                )

         # Batch start/end indices so each batch has `batch_size` samples; the last, incomplete batch is merged into the previous one.
         batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
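A toy walk-through of the merge above, with assumed numbers (10 samples, `batch_size=4`): the final index batch holds only 2 indices, so it is concatenated onto the previous one, and the last processed batch spans indices 4..9:

```python
import torch

inputs_len = 10
index_batches = [torch.arange(0, 4), torch.arange(4, 8), torch.arange(8, 10)]

last_batch = index_batches[-2]
index_batch = torch.cat((last_batch, index_batches[-1]), dim=0)  # indices 4..9
assert index_batch[-1] + 1 == inputs_len  # the new check passes
batch_start_idx, batch_end_idx = int(index_batch[0]), int(index_batch[-1]) + 1  # (4, 10)
```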
@@ -494,9 +504,6 @@ def fit(
             if logdir is not None:
                 self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")

-        assert hasattr(self, "n_features")
-        assert hasattr(self, "num_sessions")
-
     def step(self, batch: cebra.data.Batch) -> dict:
         """Perform a single gradient update.

@@ -540,7 +547,10 @@ def validation(self,
         Returns:
             Loss averaged over iterations on data batch.
         """
-        assert (session_id is None) or (session_id == 0)
+        if session_id is not None and session_id != 0:
+            raise ValueError(
+                f"session_id should be set to None or 0, got {session_id}")
+
         iterator = self._get_loader(loader)
         total_loss = Meter()
         self.model.eval()
@@ -569,7 +579,6 @@ def decoding(self, train_loader, valid_loader):
         )
         return decode_metric

-    @abc.abstractmethod
     def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
         """Check that the inputs can be inferred using the selected model.

@@ -582,7 +591,13 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
                 the number of sessions -1 for multisession, and set to
                 ``None`` for single session.
         """
-        raise NotImplementedError
+        if isinstance(inputs, list):
+            raise ValueError(
+                "Inputs to transform() should be the data for a single session, but received a list."
+            )
+        elif not isinstance(inputs, torch.Tensor):
+            raise ValueError(
+                f"Inputs should be a torch.Tensor, not {type(inputs)}.")

     @abc.abstractmethod
     def _check_is_session_id_valid(self, session_id: Optional[int] = None):
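Since the check is now concrete in the base class, every subclass rejects malformed inputs with the same error. A hedged sketch (assuming `solver` is an already-fitted single-session solver):

```python
import numpy as np
import torch

# A list (e.g. multisession-style input) is rejected up front:
solver.transform([torch.randn(100, 10), torch.randn(100, 10)])
# ValueError: Inputs to transform() should be the data for a single session, ...

# So is anything that is not a torch.Tensor, matching the narrowed annotation below:
solver.transform(np.zeros((100, 10)))
# ValueError: Inputs should be a torch.Tensor, not <class 'numpy.ndarray'>.
```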
@@ -593,7 +608,6 @@ def _check_is_session_id_valid(self, session_id: Optional[int] = None):
593608 """
594609 raise NotImplementedError
595610
596- @abc .abstractmethod
597611 def _select_model (
598612 self , inputs : Union [torch .Tensor ,
599613 List [torch .Tensor ]], session_id : Optional [int ]
@@ -610,6 +624,25 @@ def _select_model(
         Returns:
             The model (first return value) and its offset (second return value).
         """
+        model = self._get_model(session_id=session_id)
+        offset = model.get_offset()
+
+        self._check_is_inputs_valid(inputs, session_id=session_id)
+        return model, offset
+
+    @abc.abstractmethod
+    def _get_model(self,
+                   session_id: Optional[int] = None) -> cebra.models.Model:
+        """Get the model to use for inference.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Returns:
+            The model.
+        """
         raise NotImplementedError

     def _check_is_fitted(self):
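With `_select_model` made concrete, subclasses now only need to supply `_get_model`. A minimal sketch of what a single-session override could look like (illustrative only; the actual subclasses live elsewhere in the codebase and also implement `_check_is_session_id_valid`):

```python
class MySingleSessionSolver(Solver):  # hypothetical subclass, for illustration

    def _get_model(self,
                   session_id: Optional[int] = None) -> cebra.models.Model:
        # Single session: validate the ID (None or 0), then return the
        # one trained model.
        self._check_is_session_id_valid(session_id=session_id)
        return self.model
```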
@@ -627,7 +660,7 @@ def _check_is_fitted(self):

     @torch.no_grad()
     def transform(self,
-                  inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
+                  inputs: torch.Tensor,
                   pad_before_transform: Optional[bool] = True,
                   session_id: Optional[int] = None,
                   batch_size: Optional[int] = None) -> torch.Tensor:
@@ -652,26 +685,40 @@ def transform(self,
         Returns:
             The output embedding.
         """
-        if isinstance(inputs, list):
-            raise ValueError(
-                "Inputs to transform() should be the data for a single session, but received a list."
-            )
-        elif not isinstance(inputs, torch.Tensor):
-            raise ValueError(
-                f"Inputs should be a torch.Tensor, not {type(inputs)}.")
-
         self._check_is_fitted()
-
         model, offset = self._select_model(inputs, session_id)

         if len(offset) < 2 and pad_before_transform:
             pad_before_transform = False

         model.eval()
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)
+
+    @torch.no_grad()
+    def _transform(self, model: cebra.models.Model, inputs: torch.Tensor,
+                   pad_before_transform: bool,
+                   offset: cebra.data.datatypes.Offset,
+                   batch_size: Optional[int]) -> torch.Tensor:
+        """Compute the embedding on the inputs using the model provided.
+
+        Args:
+            model: Model to use for inference.
+            inputs: The input data.
+            pad_before_transform: If True, zero-pad the batched data.
+            offset: Offset of the model to consider when padding.
+            batch_size: If not None, batched inference is applied.
+
+        Returns:
+            The embedding.
+        """
         if batch_size is not None and inputs.shape[0] > int(
-                batch_size * 2) and not isinstance(
-                    self.model, cebra.models.ResampleModelMixin):
-            # NOTE: resampling models are not supported for batched inference.
+                batch_size * 2) and not (isinstance(
+                    self._get_model(0), cebra.models.ResampleModelMixin)):
+            # NOTE(celia): resampling models are not supported for batched inference.
             output = _batched_transform(
                 model=model,
                 inputs=inputs,
@@ -680,11 +727,11 @@ def transform(self,
                 pad_before_transform=pad_before_transform,
             )
         else:
-            output = _transform(model=model,
-                                inputs=inputs,
-                                offset=offset,
-                                pad_before_transform=pad_before_transform)
-
+            output = _not_batched_transform(
+                model=model,
+                inputs=inputs,
+                offset=offset,
+                pad_before_transform=pad_before_transform)
         return output

     @abc.abstractmethod
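From the caller's side, the public API is unchanged apart from the stricter input type, and batching stays opt-in via `batch_size`. A usage sketch (assuming `solver` is an already-fitted single-session solver):

```python
import torch

x = torch.randn(1000, 30)  # (time, features); only torch.Tensor is accepted now
emb = solver.transform(x)  # single-pass inference
emb_batched = solver.transform(x, batch_size=256)  # batched path; requires len(x) > 2 * batch_size
# With the default padding, the embedding keeps the input's time dimension:
assert emb.shape[0] == x.shape[0]
```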
@@ -863,3 +910,40 @@ def step(self, batch: cebra.data.Batch) -> dict:
             time_neg=time_uniform.item(),
             time_total=time_loss.item(),
         )
+
+
+class AuxiliaryVariableSolver(Solver):
+
+    @torch.no_grad()
+    def transform(self,
+                  inputs: torch.Tensor,
+                  pad_before_transform: bool = True,
+                  session_id: Optional[int] = None,
+                  batch_size: Optional[int] = None,
+                  use_reference_model: bool = False) -> torch.Tensor:
924+ """Compute the embedding.
925+ This function by default use ``model`` that was trained to encode the positive
926+ and negative samples. To use ``reference_model`` instead of ``model``
927+ ``use_reference_model`` should be equal ``True``.
928+ Args:
929+ inputs: The input signal
930+ use_reference_model: Flag for using ``reference_model``
931+ Returns:
932+ The output embedding.
933+ """
+        self._check_is_fitted()
+        model, offset = self._select_model(
+            inputs, session_id, use_reference_model=use_reference_model)
+
+        if len(offset) < 2 and pad_before_transform:
+            pad_before_transform = False
+
+        model.eval()
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)
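A short usage sketch for the new flag (assuming `solver` is a fitted `AuxiliaryVariableSolver` and `x` is a `torch.Tensor`):

```python
emb = solver.transform(x)  # uses the trained `model`
emb_ref = solver.transform(x, use_reference_model=True)  # uses `reference_model` instead
```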