1212"""Define the CEBRA model."""
1313
1414import copy
15- from typing import Callable , Iterable , List , Literal , Optional , Tuple , Union
15+ import itertools
16+ import warnings
17+ from typing import (Callable , Dict , Iterable , List , Literal , Optional , Tuple ,
18+ Union )
1619
1720import numpy as np
1821import numpy .typing as npt
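+# Used to record the installed scikit-learn version in the metadata of
+# "sklearn"-backend checkpoints (see ``CEBRA.save``).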
+import pkg_resources
import sklearn.utils.validation as sklearn_utils_validation
import torch
from sklearn.base import BaseEstimator
@@ -56,15 +60,15 @@ def _init_loader(
    algorithm, which might (depending on the arguments for this function)
    be passed to the data loader.

-    Raises
+    Raises:
        ValueError: If an argument is missing in ``extra_kwargs`` or ``shared_kwargs``
            needed to run the requested configuration.
        NotImplementedError: If the requested combination of arguments is not yet
            implemented. If this error occurs, check if the desired functionality
            is implemented in :py:mod:`cebra.data`, and consider using the CEBRA
            PyTorch API directly.

-    Returns
+    Returns:
        the data loader and name of a suitable solver

    Note:
@@ -260,6 +264,94 @@ def _require_arg(key):
            f"information to your bug report: \n" + error_message)


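+# Used by ``CEBRA.load`` for "torch"-backend checkpoints: the unpickled object
+# must be a fitted CEBRA instance.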
+def _check_type_checkpoint(checkpoint):
+    if not isinstance(checkpoint, cebra.CEBRA):
+        raise RuntimeError("Model loaded from file is not compatible with "
+                           "the current CEBRA version.")
+    if not sklearn_utils.check_fitted(checkpoint):
+        raise ValueError(
+            "CEBRA model is not fitted. Loading it is not supported.")
+
+    return checkpoint
+
+
+def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
+    """Loads a CEBRA model saved with the "sklearn" backend.
+
+    Args:
+        cebra_info: A dictionary containing information about the CEBRA object,
+            including the arguments, the state of the object and the state
+            dictionary of the model.
+
+    Returns:
+        The loaded CEBRA object.
+
+    Raises:
+        ValueError: If the loaded CEBRA model was not already fit, indicating
+            that loading it is not supported.
+    """
+    required_keys = ['args', 'state', 'state_dict']
+    missing_keys = [key for key in required_keys if key not in cebra_info]
+    if missing_keys:
+        raise ValueError(
+            f"Missing keys in data dictionary: {', '.join(missing_keys)}. "
+            f"You can try loading the CEBRA model with the torch backend.")
+
+    args, state, state_dict = (cebra_info['args'], cebra_info['state'],
+                               cebra_info['state_dict'])
+    cebra_ = cebra.CEBRA(**args)
+
+    for key, value in state.items():
+        setattr(cebra_, key, value)
+
+    state_and_args = {**args, **state}
+
+    if not sklearn_utils.check_fitted(cebra_):
+        raise ValueError(
+            "CEBRA model was not already fit. Loading it is not supported.")
+
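+    # Rebuild the torch model(s) as they were configured during fitting: a
+    # single network for single-session training, or one network per session
+    # (wrapped in an ``nn.ModuleList``) for multisession training.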
+    if cebra_.num_sessions_ is None:
+        model = cebra.models.init(
+            args["model_architecture"],
+            num_neurons=state["n_features_in_"],
+            num_units=args["num_hidden_units"],
+            num_output=args["output_dimension"],
+        ).to(state['device_'])
+
+    elif isinstance(cebra_.num_sessions_, int):
+        model = nn.ModuleList([
+            cebra.models.init(
+                args["model_architecture"],
+                num_neurons=n_features,
+                num_units=args["num_hidden_units"],
+                num_output=args["output_dimension"],
+            ) for n_features in state["n_features_in_"]
+        ]).to(state['device_'])
+
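+    # Re-create the criterion, optimizer and solver from the saved
+    # hyperparameters, then restore the trained weights from the checkpoint.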
+    criterion = cebra_._prepare_criterion()
+    criterion.to(state['device_'])
+
+    optimizer = torch.optim.Adam(
+        itertools.chain(model.parameters(), criterion.parameters()),
+        lr=args['learning_rate'],
+        **dict(args['optimizer_kwargs']),
+    )
+
+    solver = cebra.solver.init(
+        state['solver_name_'],
+        model=model,
+        criterion=criterion,
+        optimizer=optimizer,
+        tqdm_on=args['verbose'],
+    )
+    solver.load_state_dict(state_dict)
+    solver.to(state['device_'])
+
+    cebra_.model_ = model
+    cebra_.solver_ = solver
+
+    return cebra_
+
+
class CEBRA(BaseEstimator, TransformerMixin):
    """CEBRA model defined as part of a ``scikit-learn``-like API.

@@ -735,16 +827,16 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
        """
        n_idx = len(y)
        # Check that the number of indexes matches the one used for fitting
-        if len(self._label_types) != n_idx:
+        if len(self.label_types_) != n_idx:
            raise ValueError(
                f"Number of index invalid: labels must have the same number of index as for fitting,"
-                f"expects {len(self._label_types)}, got {n_idx} idx.")
+                f"expects {len(self.label_types_)}, got {n_idx} idx.")

-        for i in range(len(self._label_types)):  # for each index
+        for i in range(len(self.label_types_)):  # for each index
            if self.num_sessions is None:
-                label_types_idx = self._label_types[i]
+                label_types_idx = self.label_types_[i]
            else:
-                label_types_idx = self._label_types[i][session_id]
+                label_types_idx = self.label_types_[i][session_id]

            if (len(label_types_idx[1]) > 1 and len(y[i].shape)
                    > 1):  # is there more than one feature in the index
@@ -794,7 +886,7 @@ def _prepare_fit(
        criterion = self._prepare_criterion()
        criterion.to(self.device_)
        optimizer = torch.optim.Adam(
-            list(model.parameters()) + list(criterion.parameters()),
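+            # chain() iterates lazily over the model and criterion parameters
+            # instead of concatenating two temporary lists.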
+            itertools.chain(model.parameters(), criterion.parameters()),
            lr=self.learning_rate,
            **dict(self.optimizer_kwargs),
        )
@@ -807,8 +899,9 @@ def _prepare_fit(
            tqdm_on=self.verbose,
        )
        solver.to(self.device_)
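+        # Remember which solver was used, so that a "sklearn"-backend
+        # checkpoint can re-create it on load
+        # (see _load_cebra_with_sklearn_backend).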
+        self.solver_name_ = solver_name

-        self._label_types = ([[(y_session.dtype, y_session.shape)
+        self.label_types_ = ([[(y_session.dtype, y_session.shape)
                               for y_session in y_index]
                              for y_index in y] if is_multisession else
                             [(y_.dtype, y_.shape) for y_ in y])
@@ -1191,19 +1284,49 @@ def _more_tags(self):
        # current version of CEBRA.
        return {"non_deterministic": True}

-    def save(self, filename: str, backend: Literal["torch"] = "torch"):
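+    # Collects the fitted attributes that, together with ``get_params()``,
+    # let ``_load_cebra_with_sklearn_backend`` reconstruct this estimator.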
+    def _get_state(self):
+        cebra_dict = self.__dict__
+        state = {
+            'label_types_': cebra_dict['label_types_'],
+            'device_': cebra_dict['device_'],
+            'n_features_': cebra_dict['n_features_'],
+            'n_features_in_': cebra_dict['n_features_in_'],
+            'num_sessions_': cebra_dict['num_sessions_'],
+            'offset_': cebra_dict['offset_'],
+            'solver_name_': cebra_dict['solver_name_'],
+        }
+        return state
+
+    def save(self,
+             filename: str,
+             backend: Literal["torch", "sklearn"] = "sklearn"):
11951303 """Save the model to disk.
11961304
11971305 Args:
11981306 filename: The path to the file in which to save the trained model.
1199- backend: A string identifying the used backend.
1307+ backend: A string identifying the used backend. Default is "sklearn".
12001308
12011309 Returns:
12021310 The saved model checkpoint.
12031311
12041312 Note:
1205- Experimental functionality. Do not expect the save/load functionalities to be
1206- backward compatible yet between CEBRA versions!
1313+ The save/load functionalities may change in a future version.
1314+
1315+ File Format:
1316+ The saved model checkpoint file format depends on the specified backend.
1317+
1318+ "sklearn" backend (default):
1319+ The model is saved in a PyTorch-compatible format using `torch.save`. The saved checkpoint
1320+ is a dictionary containing the following elements:
1321+ - 'args': A dictionary of parameters used to initialize the CEBRA model.
1322+ - 'state': The state of the CEBRA model, which includes various internal attributes.
1323+ - 'state_dict': The state dictionary of the underlying solver used by CEBRA.
1324+ - 'metadata': Additional metadata about the saved model, including the backend used and the version of CEBRA PyTorch, NumPy and scikit-learn.
1325+
1326+ "torch" backend:
1327+ The model is directly saved using `torch.save` with no additional information. The saved
1328+ file contains the entire CEBRA model state.
1329+
12071330
12081331 Example:
12091332
@@ -1216,15 +1339,41 @@ def save(self, filename: str, backend: Literal["torch"] = "torch"):
        >>> cebra_model.save('/tmp/foo.pt')

        """
-        if backend != "torch":
-            raise NotImplementedError(f"Unsupported backend: {backend}")
-        checkpoint = torch.save(self, filename)
+        if sklearn_utils.check_fitted(self):
+            if backend == "torch":
+                checkpoint = torch.save(self, filename)
+
+            elif backend == "sklearn":
+                checkpoint = torch.save(
+                    {
+                        'args': self.get_params(),
+                        'state': self._get_state(),
+                        'state_dict': self.solver_.state_dict(),
+                        'metadata': {
+                            'backend': backend,
+                            'cebra_version': cebra.__version__,
+                            'torch_version': torch.__version__,
+                            'numpy_version': np.__version__,
+                            'sklearn_version': pkg_resources.get_distribution(
+                                "scikit-learn").version
+                        }
+                    }, filename)
+            else:
+                raise NotImplementedError(f"Unsupported backend: {backend}")
+        else:
+            raise ValueError("CEBRA object is not fitted. "
+                             "Saving a non-fitted model is not supported.")
        return checkpoint

    @classmethod
    def load(cls,
             filename: str,
-             backend: Literal["torch"] = "torch",
+             backend: Literal["auto", "sklearn", "torch"] = "auto",
             **kwargs) -> "CEBRA":
        """Load a model from disk.

@@ -1240,6 +1389,8 @@ def load(cls,
            Experimental functionality. Do not expect the save/load functionalities to be
            backward compatible yet between CEBRA versions!

+            For information about the file format please refer to :py:meth:`cebra.CEBRA.save`.
+
        Example:

            >>> import cebra
@@ -1249,13 +1400,32 @@ def load(cls,
            >>> embedding = loaded_model.transform(dataset)

        """
-        if backend != "torch":
-            raise NotImplementedError(f"Unsupported backend: {backend}")
-        model = torch.load(filename, **kwargs)
-        if not isinstance(model, cls):
-            raise RuntimeError("Model loaded from file is not compatible with "
-                               "the current CEBRA version.")
-        return model
+
+        supported_backends = ["auto", "sklearn", "torch"]
+        if backend not in supported_backends:
+            raise NotImplementedError(
+                f"Unsupported backend: '{backend}'. Supported backends are: "
+                f"{', '.join(supported_backends)}")
+
+        checkpoint = torch.load(filename, **kwargs)
+
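+        # Infer the backend from the checkpoint itself: "sklearn"-backend
+        # checkpoints are plain dictionaries, whereas "torch"-backend
+        # checkpoints unpickle directly into a CEBRA instance.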
+        if backend == "auto":
+            backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
+
+        if isinstance(checkpoint, dict) and backend == "torch":
+            raise RuntimeError(
+                "Cannot use 'torch' backend with a dictionary-based checkpoint. "
+                "Please try a different backend.")
+        if not isinstance(checkpoint, dict) and backend == "sklearn":
+            raise RuntimeError(
+                "Cannot use 'sklearn' backend with a non-dictionary-based "
+                "checkpoint. Please try a different backend.")
+
+        if backend == "sklearn":
+            cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
+        else:
+            cebra_ = _check_type_checkpoint(checkpoint)
+
+        return cebra_

    def to(self, device: Union[str, torch.device]):
        """Moves the cebra model to the specified device.
@@ -1282,7 +1452,7 @@ def to(self, device: Union[str, torch.device]):
            raise TypeError(
                "The 'device' parameter must be a string or torch.device object."
            )
-
+
        if isinstance(device, str):
            if (not device == 'cpu') and (not device.startswith('cuda')) and (
                    not device == 'mps'):
@@ -1292,7 +1462,8 @@ def to(self, device: Union[str, torch.device]):

        elif isinstance(device, torch.device):
            if (not device.type == 'cpu') and (
-                    not device.type.startswith('cuda')) and (not device == 'mps'):
+                    not device.type.startswith('cuda')) and (not device
+                                                             == 'mps'):
                raise ValueError(
                    "The 'device' parameter must be a valid device string or device object."
                )