2222"""Define the CEBRA model."""
2323
2424import itertools
25+ import warnings
2526from typing import (Callable , Dict , Iterable , List , Literal , Optional , Tuple ,
2627 Union )
2728
@@ -129,7 +130,7 @@ def _init_loader(
129130 (not is_cont , not is_disc , is_multi ),
130131 ]
131132 if any (all (combination ) for combination in incompatible_combinations ):
132- raise ValueError (f "Invalid index combination.\n "
133+ raise ValueError ("Invalid index combination.\n "
133134 f"Continuous: { is_cont } ,\n "
134135 f"Discrete: { is_disc } ,\n "
135136 f"Hybrid training: { is_hybrid } ,\n "
@@ -293,7 +294,7 @@ def _require_arg(key):
293294 "single-session" ,
294295 )
295296
296- error_message = (f "Invalid index combination.\n "
297+ error_message = ("Invalid index combination.\n "
297298 f"Continuous: { is_cont } ,\n "
298299 f"Discrete: { is_disc } ,\n "
299300 f"Hybrid training: { is_hybrid } ,\n "
@@ -340,7 +341,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
340341 if missing_keys :
341342 raise ValueError (
342343 f"Missing keys in data dictionary: { ', ' .join (missing_keys )} . "
343- f "You can try loading the CEBRA model with the torch backend." )
344+ "You can try loading the CEBRA model with the torch backend." )
344345
345346 args , state , state_dict = cebra_info ['args' ], cebra_info [
346347 'state' ], cebra_info ['state_dict' ]
@@ -656,12 +657,12 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
656657 # TODO(celia): to make it work for multiple set of index. For now, y should be a tuple of one list only
657658 if isinstance (y , tuple ) and len (y ) > 1 :
658659 raise NotImplementedError (
659- f "Support for multiple set of index is not implemented in multissesion training, "
660+ "Support for multiple set of index is not implemented in multissesion training, "
660661 f"got { len (y )} sets of indexes." )
661662
662663 if not _are_sessions_equal (X , y ):
663664 raise ValueError (
664- f "Invalid number of sessions: number of sessions in X and y need to match, "
665+ "Invalid number of sessions: number of sessions in X and y need to match, "
665666 f"got X:{ len (X )} and y:{ [len (y_i ) for y_i in y ]} ." )
666667
667668 for session in range (len (X )):
@@ -685,8 +686,8 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
685686 else :
686687 if not _are_sessions_equal (X , y ):
687688 raise ValueError (
688- f "Invalid number of samples or labels sessions: provide one session for single-session training, "
689- f "and make sure the number of samples in X and y need match, "
689+ "Invalid number of samples or labels sessions: provide one session for single-session training, "
690+ "and make sure the number of samples in X and y need match, "
690691 f"got { len (X )} and { [len (y_i ) for y_i in y ]} ." )
691692 is_multisession = False
692693 dataset = _get_dataset (X , y )
@@ -848,7 +849,7 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
848849 # Check that same number of index
849850 if len (self .label_types_ ) != n_idx :
850851 raise ValueError (
851- f "Number of index invalid: labels must have the same number of index as for fitting,"
852+ "Number of index invalid: labels must have the same number of index as for fitting,"
852853 f"expects { len (self .label_types_ )} , got { n_idx } idx." )
853854
854855 for i in range (len (self .label_types_ )): # for each index
@@ -861,12 +862,12 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
861862 > 1 ): # is there more than one feature in the index
862863 if label_types_idx [1 ][1 ] != y [i ].shape [1 ]:
863864 raise ValueError (
864- f "Labels invalid: must have the same number of features as the ones used for fitting,"
865+ "Labels invalid: must have the same number of features as the ones used for fitting,"
865866 f"expects { label_types_idx [1 ]} , got { y [i ].shape } ." )
866867
867868 if label_types_idx [0 ] != y [i ].dtype :
868869 raise ValueError (
869- f "Labels invalid: must have the same type of features as the ones used for fitting,"
870+ "Labels invalid: must have the same type of features as the ones used for fitting,"
870871 f"expects { label_types_idx [0 ]} , got { y [i ].dtype } ." )
871872
872873 def _prepare_fit (
@@ -1254,7 +1255,8 @@ def transform(self,
12541255
12551256 return output .detach ().cpu ().numpy ()
12561257
1257- #NOTE: Deprecated: transform is now handled in the solver but kept for testing.
1258+ #NOTE: Deprecated: transform is now handled in the solver but the original
1259+ # method is kept here for testing.
12581260 def transform_deprecated (self ,
12591261 X : Union [npt .NDArray , torch .Tensor ],
12601262 session_id : Optional [int ] = None ) -> npt .NDArray :
@@ -1279,6 +1281,12 @@ def transform_deprecated(self,
12791281 >>> embedding = cebra_model.transform(dataset)
12801282
12811283 """
1284+ warnings .warn (
1285+ "The method `transform_deprecated` is deprecated "
1286+ "but kept for testing puroposes."
1287+ "We recommend using `transform` instead." ,
1288+ DeprecationWarning ,
1289+ stacklevel = 2 )
12821290
12831291 sklearn_utils_validation .check_is_fitted (self , "n_features_" )
12841292 model , offset = self ._select_model (X , session_id )
0 commit comments