
Commit 354f36e

Add support for torch.load with weights_only=True
1 parent: 46af88e


cebra/integrations/sklearn/cebra.py

Lines changed: 21 additions & 3 deletions
@@ -28,9 +28,9 @@
 import numpy as np
 import numpy.typing as npt
 import pkg_resources
+import sklearn
 import sklearn.utils.validation as sklearn_utils_validation
 import torch
-import sklearn
 from sklearn.base import BaseEstimator
 from sklearn.base import TransformerMixin
 from sklearn.utils.metaestimators import available_if
@@ -43,12 +43,21 @@
 import cebra.models
 import cebra.solver
 
+# NOTE(stes): From torch 2.6 onwards, we need to specify the following list
+# when loading CEBRA models to allow weights_only = True.
+CEBRA_LOAD_SAFE_GLOBALS = [
+    cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
+    np.dtypes.Float64DType, np.dtypes.Int64DType
+]
+
+
 def check_version(estimator):
     # NOTE(stes): required as a check for the old way of specifying tags
     # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
     from packaging import version
     return version.parse(sklearn.__version__) < version.parse("1.6.dev")
 
+
 def _init_loader(
     is_cont: bool,
     is_disc: bool,
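
For background (not part of this commit), the following sketch illustrates why the allow-list above is needed: with weights_only=True the restricted unpickler only reconstructs allow-listed types. It assumes a torch version that provides torch.serialization.safe_globals (2.5 and later) and uses a stand-in Offset class in place of cebra.data.Offset; everything here is illustrative.

# Illustrative sketch, not CEBRA code: a stand-in class plays the role of
# cebra.data.Offset to show how the allow-list interacts with weights_only=True.
import torch


class Offset:
    def __init__(self, left: int, right: int):
        self.left = left
        self.right = right


torch.save({"offset": Offset(0, 10), "weights": torch.zeros(3)}, "ckpt.pt")

# In torch >= 2.6, loading this with weights_only=True and no allow-list fails
# with an UnpicklingError, because Offset is not a known safe type.
# Registering the class via the context manager makes the load succeed while
# keeping the restricted unpickler active.
with torch.serialization.safe_globals([Offset]):
    checkpoint = torch.load("ckpt.pt", weights_only=True)
print(checkpoint["offset"].left, checkpoint["offset"].right)
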
@@ -1409,12 +1418,18 @@ def save(self,
     def load(cls,
              filename: str,
              backend: Literal["auto", "sklearn", "torch"] = "auto",
+             weights_only: bool = True,
              **kwargs) -> "CEBRA":
         """Load a model from disk.
 
         Args:
             filename: The path to the file in which to save the trained model.
             backend: A string identifying the used backend.
+            weights_only: Indicates whether the unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`.
+                See :py:func:`torch.load` with ``weights_only=True`` for more details. It is recommended to leave this
+                at the default value of ``True``. If you experience issues with loading custom models (specified outside
+                of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model.
             kwargs: Optional keyword arguments passed directly to the loader.
 
         Return:
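
For reference, a minimal usage sketch of the new parameter; the file name "cebra_model.pt" is a placeholder for a checkpoint previously written with CEBRA.save.

# Hypothetical usage after this change.
import cebra

# Default: restricted unpickling; only tensors, primitives and the types in
# CEBRA_LOAD_SAFE_GLOBALS are reconstructed.
model = cebra.CEBRA.load("cebra_model.pt")

# Escape hatch for trusted checkpoints containing types outside the
# allow-list, e.g. custom models defined outside the CEBRA package.
unsafe_model = cebra.CEBRA.load("cebra_model.pt", weights_only=False)
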
@@ -1443,14 +1458,17 @@ def load(cls,
         >>> tmp_file.unlink()
 
         """
-
         supported_backends = ["auto", "sklearn", "torch"]
         if backend not in supported_backends:
             raise NotImplementedError(
                 f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
             )
 
-        checkpoint = torch.load(filename, **kwargs)
+        if not weights_only:
+            checkpoint = torch.load(filename, weights_only=False, **kwargs)
+        else:
+            with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
+                checkpoint = torch.load(filename, weights_only=True, **kwargs)
 
         if backend == "auto":
             backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
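
For trusted checkpoints that contain additional custom types, an alternative to weights_only=False is to extend the global allow-list before calling load. This is a sketch under two assumptions: MyEncoder and my_package are hypothetical, and globals registered via torch.serialization.add_safe_globals (available since torch 2.4) remain in effect alongside the safe_globals context manager used inside load.

# Sketch: keep the restricted unpickler but allow one extra, trusted type.
import torch
import cebra
from my_package.models import MyEncoder  # hypothetical custom class

torch.serialization.add_safe_globals([MyEncoder])
model = cebra.CEBRA.load("custom_model.pt")  # weights_only=True by default
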

0 commit comments

Comments
 (0)