@@ -791,33 +791,7 @@ def _configure_for_all(
791791
792792 def _select_model (self , X : Union [npt .NDArray , torch .Tensor ],
793793 session_id : int ):
794- # Choose the model and get its corresponding offset
795- if self .num_sessions is not None : # multisession implementation
796- if session_id is None :
797- raise RuntimeError (
798- "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
799- )
800- if session_id >= self .num_sessions or session_id < 0 :
801- raise RuntimeError (
802- f"Invalid session_id { session_id } : session_id for the current multisession model must be between 0 and { self .num_sessions - 1 } ."
803- )
804- if self .n_features_ [session_id ] != X .shape [1 ]:
805- raise ValueError (
806- f"Invalid input shape: model for session { session_id } requires an input of shape"
807- f"(n_samples, { self .n_features_ [session_id ]} ), got (n_samples, { X .shape [1 ]} )."
808- )
809-
810- model = self .model_ [session_id ]
811- model .to (self .device_ )
812- else : # single session
813- if session_id is not None and session_id > 0 :
814- raise RuntimeError (
815- f"Invalid session_id { session_id } : single session models only takes an optional null session_id."
816- )
817- model = self .model_
818-
819- offset = model .get_offset ()
820- return model , offset
794+ return self .solver_ ._select_model (X , session_id = session_id )
821795
822796 def _check_labels_types (self , y : tuple , session_id : Optional [int ] = None ):
823797 """Check that the input labels are compatible with the labels used to fit the model.
@@ -1224,16 +1198,16 @@ def transform(self,
12241198 >>> embedding = cebra_model.transform(dataset)
12251199
12261200 """
1227-
1201+ self . solver_ . _check_is_session_id_valid ( session_id = session_id )
12281202 sklearn_utils_validation .check_is_fitted (self , "n_features_" )
1229- # Input validation
1230- #TODO: if inputs are in cuda, then it throws an error, deal with this.
1203+
1204+ if torch .is_tensor (X ) and X .device .type == "cuda" :
1205+ X = X .detach ().cpu ()
1206+
12311207 X = sklearn_utils .check_input_array (X , min_samples = len (self .offset_ ))
1232- #input_dtype = X.dtype
12331208
12341209 if isinstance (X , np .ndarray ):
12351210 X = torch .from_numpy (X )
1236- # TODO: which type and device should be put there?
12371211
12381212 with torch .no_grad ():
12391213 output = self .solver_ .transform (
@@ -1242,11 +1216,7 @@ def transform(self,
12421216 session_id = session_id ,
12431217 batch_size = batch_size )
12441218
1245- #TODO: check if this is safe.
1246- return output .numpy (force = True )
1247-
1248- #if input_dtype == "float64":
1249- # return output.astype(input_dtype)
1219+ return output .detach ().cpu ().numpy ()
12501220
12511221 def fit_transform (
12521222 self ,
0 commit comments