@@ -1235,7 +1235,7 @@ def transform(self,
12351235 sklearn_utils_validation .check_is_fitted (self , "n_features_" )
12361236 self .solver_ ._check_is_session_id_valid (session_id = session_id )
12371237
1238- if torch .is_tensor (X ) and X . device . type == "cuda" :
1238+ if torch .is_tensor (X ):
12391239 X = X .detach ().cpu ()
12401240
12411241 X = sklearn_utils .check_input_array (X , min_samples = len (self .offset_ ))
@@ -1256,6 +1256,60 @@ def transform(self,
12561256
12571257 return output .detach ().cpu ().numpy ()
12581258
1259+ # Deprecated, kept for testing.
1260+ def transform_deprecated (self ,
1261+ X : Union [npt .NDArray , torch .Tensor ],
1262+ session_id : Optional [int ] = None ) -> npt .NDArray :
1263+ """Transform an input sequence and return the embedding.
1264+
1265+ Args:
1266+ X: A numpy array or torch tensor of size ``time x dimension``.
1267+ session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
1268+ multisession, set to ``None`` for single session.
1269+
1270+ Returns:
1271+ A :py:func:`numpy.array` of size ``time x output_dimension``.
1272+
1273+ Example:
1274+
1275+ >>> import cebra
1276+ >>> import numpy as np
1277+ >>> dataset = np.random.uniform(0, 1, (1000, 30))
1278+ >>> cebra_model = cebra.CEBRA(max_iterations=10)
1279+ >>> cebra_model.fit(dataset)
1280+ CEBRA(max_iterations=10)
1281+ >>> embedding = cebra_model.transform(dataset)
1282+
1283+ """
1284+
1285+ sklearn_utils_validation .check_is_fitted (self , "n_features_" )
1286+ model , offset = self ._select_model (X , session_id )
1287+
1288+ # Input validation
1289+ X = sklearn_utils .check_input_array (X , min_samples = len (self .offset_ ))
1290+ input_dtype = X .dtype
1291+
1292+ with torch .no_grad ():
1293+ model .eval ()
1294+
1295+ if self .pad_before_transform :
1296+ X = np .pad (X , ((offset .left , offset .right - 1 ), (0 , 0 )),
1297+ mode = "edge" )
1298+ X = torch .from_numpy (X ).float ().to (self .device_ )
1299+
1300+ if isinstance (model , cebra .models .ConvolutionalModelMixin ):
1301+ # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
1302+ X = X .transpose (1 , 0 ).unsqueeze (0 )
1303+ output = model (X ).cpu ().numpy ().squeeze (0 ).transpose (1 , 0 )
1304+ else :
1305+ # Standard evaluation, (T, C, dt)
1306+ output = model (X ).cpu ().numpy ()
1307+
1308+ if input_dtype == "float64" :
1309+ return output .astype (input_dtype )
1310+
1311+ return output
1312+
12591313 def fit_transform (
12601314 self ,
12611315 X : Union [npt .NDArray , torch .Tensor ],
0 commit comments