from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import literate_dataclasses as dataclasses
-import numpy.typing as npt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
@@ -104,6 +103,15 @@ def _add_batched_zero_padding(batched_data: torch.Tensor,
    Returns:
        The padded batch.
    """
+    if batch_start_idx > batch_end_idx:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) cannot be greater than batch_end_idx ({batch_end_idx})."
+        )
+    if batch_start_idx < 0 or batch_end_idx < 0:
+        raise ValueError(
+            f"batch_start_idx ({batch_start_idx}) and batch_end_idx ({batch_end_idx}) must be non-negative."
+        )
+
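+    # Dimension indices in reverse order (last dimension first), used by the
+    # padding logic below.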
    reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1)

    if batch_start_idx == 0:  # First batch
@@ -179,7 +187,7 @@ def _inference_transform(model: cebra.models.Model,
    return output


-def _transform(
+def _not_batched_transform(
    model: cebra.models.Model,
    inputs: torch.Tensor,
    pad_before_transform: bool,
@@ -253,9 +261,11 @@ def __getitem__(self, idx):
        if batch_idx == (len(index_dataloader) - 1):
            # last batch, incomplete
            index_batch = torch.cat((last_batch, index_batch), dim=0)
-            assert index_batch[-1] + 1 == len(inputs), (
-                f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
-            )
+
+            if index_batch[-1] + 1 != len(inputs):
+                raise ValueError(
+                    f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
+                )

        # Batch start and end indices: every batch spans `batch_size` samples,
        # with the final batch also absorbing the leftover incomplete batch.
        batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
@@ -469,9 +479,6 @@ def fit(
            if logdir is not None:
                self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")

-        assert hasattr(self, "n_features")
-        assert hasattr(self, "num_sessions")
-
    def step(self, batch: cebra.data.Batch) -> dict:
        """Perform a single gradient update.

@@ -515,7 +522,10 @@ def validation(self,
        Returns:
            Loss averaged over iterations on data batch.
        """
-        assert (session_id is None) or (session_id == 0)
+        if session_id is not None and session_id != 0:
+            raise ValueError(
+                f"session_id should be set to None or 0, got {session_id}")
+
        iterator = self._get_loader(loader)
        total_loss = Meter()
        self.model.eval()
@@ -544,7 +554,6 @@ def decoding(self, train_loader, valid_loader):
        )
        return decode_metric

-    @abc.abstractmethod
    def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
        """Check that the inputs can be inferred using the selected model.

@@ -557,7 +566,13 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
                the number of sessions -1 for multisession, and set to
                ``None`` for single session.
        """
-        raise NotImplementedError
+        if isinstance(inputs, list):
+            raise ValueError(
+                "Inputs to transform() should be the data for a single session, but received a list."
+            )
+        elif not isinstance(inputs, torch.Tensor):
+            raise ValueError(
+                f"Inputs should be a torch.Tensor, not {type(inputs)}.")

562577 @abc .abstractmethod
563578 def _check_is_session_id_valid (self , session_id : Optional [int ] = None ):
@@ -568,7 +583,6 @@ def _check_is_session_id_valid(self, session_id: Optional[int] = None):
        """
        raise NotImplementedError

-    @abc.abstractmethod
    def _select_model(
        self, inputs: Union[torch.Tensor,
                            List[torch.Tensor]], session_id: Optional[int]
@@ -585,6 +599,25 @@ def _select_model(
        Returns:
            The model (first return value) and the offset of the model (second return value).
        """
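+        # The concrete model lookup is delegated to ``_get_model``, which the
+        # single- and multi-session subclasses implement.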
+        model = self._get_model(session_id=session_id)
+        offset = model.get_offset()
+
+        self._check_is_inputs_valid(inputs, session_id=session_id)
+        return model, offset
+
+    @abc.abstractmethod
+    def _get_model(self,
+                   session_id: Optional[int] = None) -> cebra.models.Model:
+        """Get the model to use for inference.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Returns:
+            The model.
+        """
        raise NotImplementedError

    def _check_is_fitted(self):
@@ -602,7 +635,7 @@ def _check_is_fitted(self):

    @torch.no_grad()
    def transform(self,
-                  inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
+                  inputs: torch.Tensor,
                  pad_before_transform: Optional[bool] = True,
                  session_id: Optional[int] = None,
                  batch_size: Optional[int] = None) -> torch.Tensor:
@@ -627,26 +660,40 @@ def transform(self,
        Returns:
            The output embedding.
        """
-        if isinstance(inputs, list):
-            raise ValueError(
-                "Inputs to transform() should be the data for a single session, but received a list."
-            )
-        elif not isinstance(inputs, torch.Tensor):
-            raise ValueError(
-                f"Inputs should be a torch.Tensor, not {type(inputs)}.")
-
        self._check_is_fitted()
-
        model, offset = self._select_model(inputs, session_id)

        if len(offset) < 2 and pad_before_transform:
            pad_before_transform = False

        model.eval()
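+        # Inference itself is delegated to the shared ``_transform`` helper
+        # defined below.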
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)
+
+    @torch.no_grad()
+    def _transform(self, model: cebra.models.Model, inputs: torch.Tensor,
+                   pad_before_transform: bool,
+                   offset: cebra.data.datatypes.Offset,
+                   batch_size: Optional[int]) -> torch.Tensor:
+        """Compute the embedding on the inputs using the model provided.
+
+        Args:
+            model: Model to use for inference.
+            inputs: Data.
+            pad_before_transform: If True, zero-pad the batched data.
+            offset: Offset of the model to consider when padding.
+            batch_size: If not None, batched inference will be applied.
+
+        Returns:
+            The embedding.
+        """
        if batch_size is not None and inputs.shape[0] > int(
-                batch_size * 2) and not isinstance(
-                    self.model, cebra.models.ResampleModelMixin):
-            # NOTE: resampling models are not supported for batched inference.
+                batch_size * 2) and not isinstance(
+                    self._get_model(0), cebra.models.ResampleModelMixin):
+            # NOTE(celia): resampling models are not supported for batched inference.
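+            # The `batch_size * 2` threshold guarantees at least two index
+            # batches, since the final incomplete batch is merged into the
+            # previous one.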
            output = _batched_transform(
                model=model,
                inputs=inputs,
@@ -655,11 +702,11 @@ def transform(self,
                pad_before_transform=pad_before_transform,
            )
        else:
-            output = _transform(model=model,
-                                inputs=inputs,
-                                offset=offset,
-                                pad_before_transform=pad_before_transform)
-
+            output = _not_batched_transform(
+                model=model,
+                inputs=inputs,
+                offset=offset,
+                pad_before_transform=pad_before_transform)
        return output

    @abc.abstractmethod
@@ -838,3 +885,37 @@ def step(self, batch: cebra.data.Batch) -> dict:
            time_neg=time_uniform.item(),
            time_total=time_loss.item(),
        )
+
+
+class AuxiliaryVariableSolver(Solver):
+
+    @torch.no_grad()
+    def transform(self,
+                  inputs: torch.Tensor,
+                  pad_before_transform: bool = True,
+                  session_id: Optional[int] = None,
+                  batch_size: Optional[int] = None,
+                  use_reference_model: bool = False) -> torch.Tensor:
+        """Compute the embedding.
+
+        By default, this function uses the ``model`` that was trained to encode
+        the positive and negative samples. To use ``reference_model`` instead,
+        set ``use_reference_model=True``.
+
+        Args:
+            inputs: The input signal.
+            use_reference_model: Flag for using ``reference_model``.
+
+        Returns:
+            The output embedding.
+        """
+        self._check_is_fitted()
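+        # NOTE: ``use_reference_model`` is forwarded to ``_select_model``; this
+        # assumes the subclass override accepts that keyword argument.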
+        model, offset = self._select_model(
+            inputs, session_id, use_reference_model=use_reference_model)
+
+        if len(offset) < 2 and pad_before_transform:
+            pad_before_transform = False
+
+        model.eval()
+        return self._transform(model=model,
+                               inputs=inputs,
+                               pad_before_transform=pad_before_transform,
+                               offset=offset,
+                               batch_size=batch_size)