|
28 | 28 | import numpy as np |
29 | 29 | import numpy.typing as npt |
30 | 30 | import pkg_resources |
| 31 | +import sklearn |
31 | 32 | import sklearn.utils.validation as sklearn_utils_validation |
32 | 33 | import torch |
33 | | -import sklearn |
34 | 34 | from sklearn.base import BaseEstimator |
35 | 35 | from sklearn.base import TransformerMixin |
36 | 36 | from sklearn.utils.metaestimators import available_if |
|
43 | 43 | import cebra.models |
44 | 44 | import cebra.solver |
45 | 45 |
|
# NOTE(stes): From torch 2.6 onwards, we need to specify the following list
# when loading CEBRA models to allow weights_only = True.
# These are the globals a pickled CEBRA checkpoint is known to contain beyond
# plain tensors/primitives, so they must be allow-listed for safe loading.
# NOTE(review): ``np.dtypes`` only exists in numpy >= 1.25 and the
# ``safe_globals`` loading machinery only in recent torch releases — confirm
# the package's minimum supported versions cover both.
CEBRA_LOAD_SAFE_GLOBALS = [
    cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
    np.dtypes.Float64DType, np.dtypes.Int64DType
]
| 53 | + |
def check_version(estimator):
    """Return ``True`` when the installed scikit-learn predates 1.6.

    The ``estimator`` argument is accepted for call-site compatibility but is
    not inspected; only the installed scikit-learn version matters.
    """
    # NOTE(stes): required as a check for the old way of specifying tags
    # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
    from packaging import version
    legacy_cutoff = version.parse("1.6.dev")
    return version.parse(sklearn.__version__) < legacy_cutoff
51 | 59 |
|
| 60 | + |
52 | 61 | def _init_loader( |
53 | 62 | is_cont: bool, |
54 | 63 | is_disc: bool, |
@@ -1409,12 +1418,18 @@ def save(self, |
1409 | 1418 | def load(cls, |
1410 | 1419 | filename: str, |
1411 | 1420 | backend: Literal["auto", "sklearn", "torch"] = "auto", |
| 1421 | + weights_only: bool = True, |
1412 | 1422 | **kwargs) -> "CEBRA": |
1413 | 1423 | """Load a model from disk. |
1414 | 1424 |
|
1415 | 1425 | Args: |
1416 | 1426 | filename: The path to the file in which to save the trained model. |
1417 | 1427 | backend: A string identifying the used backend. |
            weights_only: Indicates whether unpickler should be restricted to loading only tensors, primitive types,
                dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`.
                See :py:func:`torch.load` with ``weights_only=True`` for more details. It is recommended to leave this
                at the default value of ``True``. If you experience issues with loading custom models (specified outside
                of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model.
1418 | 1433 | kwargs: Optional keyword arguments passed directly to the loader. |
1419 | 1434 |
|
1420 | 1435 | Return: |
@@ -1443,14 +1458,17 @@ def load(cls, |
1443 | 1458 | >>> tmp_file.unlink() |
1444 | 1459 |
|
1445 | 1460 | """ |
1446 | | - |
1447 | 1461 | supported_backends = ["auto", "sklearn", "torch"] |
1448 | 1462 | if backend not in supported_backends: |
1449 | 1463 | raise NotImplementedError( |
1450 | 1464 | f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}" |
1451 | 1465 | ) |
1452 | 1466 |
|
1453 | | - checkpoint = torch.load(filename, **kwargs) |
| 1467 | + if not weights_only: |
| 1468 | + checkpoint = torch.load(filename, weights_only=False, **kwargs) |
| 1469 | + else: |
| 1470 | + with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): |
| 1471 | + checkpoint = torch.load(filename, weights_only=True, **kwargs) |
1454 | 1472 |
|
1455 | 1473 | if backend == "auto": |
1456 | 1474 | backend = "sklearn" if isinstance(checkpoint, dict) else "torch" |
|
0 commit comments