Skip to content

Commit eaa8940

Browse files
gonlairostesMMathisLab
authored
Improve saving and loading of models (#69)
* first attempt of saving/loading models the right way * some progress * first working proposal * small progress * improve the API. Found self.get_params() * simplify code + first test pass * fix spelling * raise ValueError + add multisession support * fix name * fix typo * improve test + code * organize code better + add more checks * improve tests * fix import in test * improve docs * add suggestions * remove comments in test * fix typo test * fix docs + delete unnecessary check * remove args from state * improvements * fix typo Co-authored-by: Steffen Schneider <[email protected]> * fix typo Co-authored-by: Steffen Schneider <[email protected]> * fix typo Co-authored-by: Steffen Schneider <[email protected]> * change _label_types to label_types_ * fix docs * Update cebra/integrations/sklearn/cebra.py Co-authored-by: Steffen Schneider <[email protected]> * Update cebra/integrations/sklearn/cebra.py Co-authored-by: Steffen Schneider <[email protected]> * improve typing Co-authored-by: Steffen Schneider <[email protected]> * improve typing Co-authored-by: Steffen Schneider <[email protected]> * fix typo + pre-commit * better docstrings * fix unindent error in docs * Update cebra.py - refine message on fxn --------- Co-authored-by: Steffen Schneider <[email protected]> Co-authored-by: Steffen Schneider <[email protected]> Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent f65c88d commit eaa8940

File tree

2 files changed

+275
-46
lines changed

2 files changed

+275
-46
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 198 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
"""Define the CEBRA model."""
1313

1414
import copy
15-
from typing import Callable, Iterable, List, Literal, Optional, Tuple, Union
15+
import itertools
16+
import warnings
17+
from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
18+
Union)
1619

1720
import numpy as np
1821
import numpy.typing as npt
22+
import pkg_resources
1923
import sklearn.utils.validation as sklearn_utils_validation
2024
import torch
2125
from sklearn.base import BaseEstimator
@@ -56,15 +60,15 @@ def _init_loader(
5660
algorithm, which might (depending on the arguments for this function)
5761
be passed to the data loader.
5862
59-
Raises
63+
Raises:
6064
ValueError: If an argument is missing in ``extra_kwargs`` or ``shared_kwargs``
6165
needed to run the requested configuration.
6266
NotImplementedError: If the requested combinations of arguments is not yet
6367
implemented. If this error occurs, check if the desired functionality
6468
is implemented in :py:mod:`cebra.data`, and consider using the CEBRA
6569
PyTorch API directly.
6670
67-
Returns
71+
Returns:
6872
the data loader and name of a suitable solver
6973
7074
Note:
@@ -260,6 +264,94 @@ def _require_arg(key):
260264
f"information to your bug report: \n" + error_message)
261265

262266

267+
def _check_type_checkpoint(checkpoint):
    """Validate a checkpoint object that was loaded from disk.

    Args:
        checkpoint: The object read from the checkpoint file.

    Returns:
        The unchanged ``checkpoint`` when it passes both checks.

    Raises:
        RuntimeError: If ``checkpoint`` is not a ``cebra.CEBRA`` instance.
        ValueError: If ``sklearn_utils.check_fitted`` reports the model as
            not fitted.
    """
    is_cebra_model = isinstance(checkpoint, cebra.CEBRA)
    if not is_cebra_model:
        raise RuntimeError("Model loaded from file is not compatible with "
                           "the current CEBRA version.")

    # Only a fitted estimator is handed back to the caller.
    if sklearn_utils.check_fitted(checkpoint):
        return checkpoint

    raise ValueError(
        "CEBRA model is not fitted. Loading it is not supported.")
276+
277+
278+
def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
279+
"""Loads a CEBRA model with a Sklearn backend.
280+
281+
Args:
282+
cebra_info: A dictionary containing information about the CEBRA object,
283+
including the arguments, the state of the object and the state
284+
dictionary of the model.
285+
286+
Returns:
287+
The loaded CEBRA object.
288+
289+
Raises:
290+
ValueError: If the loaded CEBRA model was not already fit, indicating that loading it is not supported.
291+
"""
292+
required_keys = ['args', 'state', 'state_dict']
293+
missing_keys = [key for key in required_keys if key not in cebra_info]
294+
if missing_keys:
295+
raise ValueError(
296+
f"Missing keys in data dictionary: {', '.join(missing_keys)}. "
297+
f"You can try loading the CEBRA model with the torch backend.")
298+
299+
args, state, state_dict = cebra_info['args'], cebra_info[
300+
'state'], cebra_info['state_dict']
301+
cebra_ = cebra.CEBRA(**args)
302+
303+
for key, value in state.items():
304+
setattr(cebra_, key, value)
305+
306+
state_and_args = {**args, **state}
307+
308+
if not sklearn_utils.check_fitted(cebra_):
309+
raise ValueError(
310+
"CEBRA model was not already fit. Loading it is not supported.")
311+
312+
if cebra_.num_sessions_ is None:
313+
model = cebra.models.init(
314+
args["model_architecture"],
315+
num_neurons=state["n_features_in_"],
316+
num_units=args["num_hidden_units"],
317+
num_output=args["output_dimension"],
318+
).to(state['device_'])
319+
320+
elif isinstance(cebra_.num_sessions_, int):
321+
model = nn.ModuleList([
322+
cebra.models.init(
323+
args["model_architecture"],
324+
num_neurons=n_features,
325+
num_units=args["num_hidden_units"],
326+
num_output=args["output_dimension"],
327+
) for n_features in state["n_features_in_"]
328+
]).to(state['device_'])
329+
330+
criterion = cebra_._prepare_criterion()
331+
criterion.to(state['device_'])
332+
333+
optimizer = torch.optim.Adam(
334+
itertools.chain(model.parameters(), criterion.parameters()),
335+
lr=args['learning_rate'],
336+
**dict(args['optimizer_kwargs']),
337+
)
338+
339+
solver = cebra.solver.init(
340+
state['solver_name_'],
341+
model=model,
342+
criterion=criterion,
343+
optimizer=optimizer,
344+
tqdm_on=args['verbose'],
345+
)
346+
solver.load_state_dict(state_dict)
347+
solver.to(state['device_'])
348+
349+
cebra_.model_ = model
350+
cebra_.solver_ = solver
351+
352+
return cebra_
353+
354+
263355
class CEBRA(BaseEstimator, TransformerMixin):
264356
"""CEBRA model defined as part of a ``scikit-learn``-like API.
265357
@@ -735,16 +827,16 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
735827
"""
736828
n_idx = len(y)
737829
# Check that same number of index
738-
if len(self._label_types) != n_idx:
830+
if len(self.label_types_) != n_idx:
739831
raise ValueError(
740832
f"Number of index invalid: labels must have the same number of index as for fitting,"
741-
f"expects {len(self._label_types)}, got {n_idx} idx.")
833+
f"expects {len(self.label_types_)}, got {n_idx} idx.")
742834

743-
for i in range(len(self._label_types)): # for each index
835+
for i in range(len(self.label_types_)): # for each index
744836
if self.num_sessions is None:
745-
label_types_idx = self._label_types[i]
837+
label_types_idx = self.label_types_[i]
746838
else:
747-
label_types_idx = self._label_types[i][session_id]
839+
label_types_idx = self.label_types_[i][session_id]
748840

749841
if (len(label_types_idx[1]) > 1 and len(y[i].shape)
750842
> 1): # is there more than one feature in the index
@@ -794,7 +886,7 @@ def _prepare_fit(
794886
criterion = self._prepare_criterion()
795887
criterion.to(self.device_)
796888
optimizer = torch.optim.Adam(
797-
list(model.parameters()) + list(criterion.parameters()),
889+
itertools.chain(model.parameters(), criterion.parameters()),
798890
lr=self.learning_rate,
799891
**dict(self.optimizer_kwargs),
800892
)
@@ -807,8 +899,9 @@ def _prepare_fit(
807899
tqdm_on=self.verbose,
808900
)
809901
solver.to(self.device_)
902+
self.solver_name_ = solver_name
810903

811-
self._label_types = ([[(y_session.dtype, y_session.shape)
904+
self.label_types_ = ([[(y_session.dtype, y_session.shape)
812905
for y_session in y_index]
813906
for y_index in y] if is_multisession else
814907
[(y_.dtype, y_.shape) for y_ in y])
@@ -1191,19 +1284,49 @@ def _more_tags(self):
11911284
# current version of CEBRA.
11921285
return {"non_deterministic": True}
11931286

1194-
def save(self, filename: str, backend: Literal["torch"] = "torch"):
1287+
def _get_state(self):
1288+
cebra_dict = self.__dict__
1289+
state = {
1290+
'label_types_': cebra_dict['label_types_'],
1291+
'device_': cebra_dict['device_'],
1292+
'n_features_': cebra_dict['n_features_'],
1293+
'n_features_in_': cebra_dict['n_features_in_'],
1294+
'num_sessions_': cebra_dict['num_sessions_'],
1295+
'offset_': cebra_dict['offset_'],
1296+
'solver_name_': cebra_dict['solver_name_'],
1297+
}
1298+
return state
1299+
1300+
def save(self,
1301+
filename: str,
1302+
backend: Literal["torch", "sklearn"] = "sklearn"):
11951303
"""Save the model to disk.
11961304
11971305
Args:
11981306
filename: The path to the file in which to save the trained model.
1199-
backend: A string identifying the used backend.
1307+
backend: A string identifying the used backend. Default is "sklearn".
12001308
12011309
Returns:
12021310
The saved model checkpoint.
12031311
12041312
Note:
1205-
Experimental functionality. Do not expect the save/load functionalities to be
1206-
backward compatible yet between CEBRA versions!
1313+
The save/load functionalities may change in a future version.
1314+
1315+
File Format:
1316+
The saved model checkpoint file format depends on the specified backend.
1317+
1318+
"sklearn" backend (default):
1319+
The model is saved in a PyTorch-compatible format using `torch.save`. The saved checkpoint
1320+
is a dictionary containing the following elements:
1321+
- 'args': A dictionary of parameters used to initialize the CEBRA model.
1322+
- 'state': The state of the CEBRA model, which includes various internal attributes.
1323+
- 'state_dict': The state dictionary of the underlying solver used by CEBRA.
1324+
- 'metadata': Additional metadata about the saved model, including the backend used and the versions of CEBRA, PyTorch, NumPy and scikit-learn.
1325+
1326+
"torch" backend:
1327+
The model is directly saved using `torch.save` with no additional information. The saved
1328+
file contains the entire CEBRA model state.
1329+
12071330
12081331
Example:
12091332
@@ -1216,15 +1339,41 @@ def save(self, filename: str, backend: Literal["torch"] = "torch"):
12161339
>>> cebra_model.save('/tmp/foo.pt')
12171340
12181341
"""
1219-
if backend != "torch":
1220-
raise NotImplementedError(f"Unsupported backend: {backend}")
1221-
checkpoint = torch.save(self, filename)
1342+
if sklearn_utils.check_fitted(self):
1343+
if backend == "torch":
1344+
checkpoint = torch.save(self, filename)
1345+
1346+
elif backend == "sklearn":
1347+
checkpoint = torch.save(
1348+
{
1349+
'args': self.get_params(),
1350+
'state': self._get_state(),
1351+
'state_dict': self.solver_.state_dict(),
1352+
'metadata': {
1353+
'backend':
1354+
backend,
1355+
'cebra_version':
1356+
cebra.__version__,
1357+
'torch_version':
1358+
torch.__version__,
1359+
'numpy_version':
1360+
np.__version__,
1361+
'sklearn_version':
1362+
pkg_resources.get_distribution("scikit-learn"
1363+
).version
1364+
}
1365+
}, filename)
1366+
else:
1367+
raise NotImplementedError(f"Unsupported backend: {backend}")
1368+
else:
1369+
raise ValueError("CEBRA object is not fitted. "
1370+
"Saving a non-fitted model is not supported.")
12221371
return checkpoint
12231372

12241373
@classmethod
12251374
def load(cls,
12261375
filename: str,
1227-
backend: Literal["torch"] = "torch",
1376+
backend: Literal["auto", "sklearn", "torch"] = "auto",
12281377
**kwargs) -> "CEBRA":
12291378
"""Load a model from disk.
12301379
@@ -1240,6 +1389,8 @@ def load(cls,
12401389
Experimental functionality. Do not expect the save/load functionalities to be
12411390
backward compatible yet between CEBRA versions!
12421391
1392+
For information about the file format please refer to :py:meth:`cebra.CEBRA.save`.
1393+
12431394
Example:
12441395
12451396
>>> import cebra
@@ -1249,13 +1400,32 @@ def load(cls,
12491400
>>> embedding = loaded_model.transform(dataset)
12501401
12511402
"""
1252-
if backend != "torch":
1253-
raise NotImplementedError(f"Unsupported backend: {backend}")
1254-
model = torch.load(filename, **kwargs)
1255-
if not isinstance(model, cls):
1256-
raise RuntimeError("Model loaded from file is not compatible with "
1257-
"the current CEBRA version.")
1258-
return model
1403+
1404+
supported_backends = ["auto", "sklearn", "torch"]
1405+
if backend not in supported_backends:
1406+
raise NotImplementedError(
1407+
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
1408+
)
1409+
1410+
checkpoint = torch.load(filename, **kwargs)
1411+
1412+
if backend == "auto":
1413+
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
1414+
1415+
if isinstance(checkpoint, dict) and backend == "torch":
1416+
raise RuntimeError(
1417+
f"Cannot use 'torch' backend with a dictionary-based checkpoint. "
1418+
f"Please try a different backend.")
1419+
if not isinstance(checkpoint, dict) and backend == "sklearn":
1420+
raise RuntimeError(f"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1421+
f"Please try a different backend.")
1422+
1423+
if backend == "sklearn":
1424+
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
1425+
else:
1426+
cebra_ = _check_type_checkpoint(checkpoint)
1427+
1428+
return cebra_
12591429

12601430
def to(self, device: Union[str, torch.device]):
12611431
"""Moves the cebra model to the specified device.
@@ -1282,7 +1452,7 @@ def to(self, device: Union[str, torch.device]):
12821452
raise TypeError(
12831453
"The 'device' parameter must be a string or torch.device object."
12841454
)
1285-
1455+
12861456
if isinstance(device, str):
12871457
if (not device == 'cpu') and (not device.startswith('cuda')) and (
12881458
not device == 'mps'):
@@ -1292,7 +1462,8 @@ def to(self, device: Union[str, torch.device]):
12921462

12931463
elif isinstance(device, torch.device):
12941464
if (not device.type == 'cpu') and (
1295-
not device.type.startswith('cuda')) and (not device == 'mps'):
1465+
not device.type.startswith('cuda')) and (not device
1466+
== 'mps'):
12961467
raise ValueError(
12971468
"The 'device' parameter must be a valid device string or device object."
12981469
)

0 commit comments

Comments
 (0)