Commit 2a53f2f

weights_only=True default for torch>=2.6
1 parent 93cbe94 commit 2a53f2f

7 files changed: 29 additions, 23 deletions

src/lightning/fabric/plugins/io/checkpoint_io.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -48,18 +48,18 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
 
     @abstractmethod
     def load_checkpoint(
-        self, path: _PATH, map_location: Optional[Any] = None, weights_only: bool = True
+        self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None
     ) -> dict[str, Any]:
         """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
 
         Args:
             path: Path to checkpoint
             map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                 locations.
-            weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
-                primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
-                ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
-                ``weights_only=True``. For more information, please refer to the
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
+                an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
+                recommend using ``weights_only=True``. For more information, please refer to the
                 `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
 
         Returns: The loaded checkpoint.
```
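For illustration only (not part of this commit), here is a minimal standalone sketch of the `weights_only` semantics the docstring describes, using plain `torch.save`/`torch.load`; the file names are hypothetical:

```python
import torch
from torch import nn

# A plain state_dict holds only tensors/primitives: safe under weights_only=True.
torch.save(nn.Linear(2, 2).state_dict(), "weights.pt")
state = torch.load("weights.pt", weights_only=True)

# A pickled nn.Module requires unpickling arbitrary classes, so the restricted
# loader rejects it; a checkpoint from a trusted source can opt out explicitly.
torch.save(nn.Linear(2, 2), "module.pt")
try:
    torch.load("module.pt", weights_only=True)
except Exception as err:  # an UnpicklingError-style failure naming the blocked global
    print(f"weights_only=True refused the module checkpoint: {err}")
model = torch.load("module.pt", weights_only=False)  # trusted sources only
```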

src/lightning/fabric/plugins/io/torch_io.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -59,14 +59,22 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
 
     @override
     def load_checkpoint(
-        self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage, weights_only: bool = True
+        self,
+        path: _PATH,
+        map_location: Optional[Callable] = lambda storage, loc: storage,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.
 
         Args:
             path: Path to checkpoint
             map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                 locations.
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
+                an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
+                recommend using ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
 
         Returns: The loaded checkpoint.
 
```
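As a usage sketch (assuming the public `TorchCheckpointIO` export; the checkpoint paths are hypothetical), the new default means callers who pass nothing get the version-resolved behavior:

```python
from lightning.fabric.plugins import TorchCheckpointIO

io = TorchCheckpointIO()

# weights_only is left at None: on torch>=2.6 the shared loader resolves it to
# True, so this call only succeeds for tensor/primitive-only checkpoints.
checkpoint = io.load_checkpoint("weights.ckpt")

# Trusted checkpoints containing richer pickled objects opt out explicitly.
checkpoint = io.load_checkpoint("trusted.ckpt", weights_only=False)
```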

src/lightning/fabric/utilities/cloud_io.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -17,7 +17,7 @@
 import io
 import logging
 from pathlib import Path
-from typing import IO, Any, Union
+from typing import IO, Any, Optional, Union
 
 import fsspec
 import fsspec.utils
@@ -26,6 +26,7 @@
 from fsspec.implementations.local import AbstractFileSystem
 from lightning_utilities.core.imports import module_available
 
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 
 log = logging.getLogger(__name__)
@@ -34,7 +35,7 @@
 def _load(
     path_or_url: Union[IO, _PATH],
     map_location: _MAP_LOCATION_TYPE = None,
-    weights_only: bool = True,
+    weights_only: Optional[bool] = None,
 ) -> Any:
     """Loads a checkpoint.
 
@@ -48,6 +49,11 @@ def _load(
             `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
 
     """
+    # default to `weights_only=True` for torch>=2.6
+    if weights_only is None and _TORCH_GREATER_EQUAL_2_6:
+        log.debug("Defaulting to `weights_only=True` for torch>=2.6.")
+        weights_only = True
+
    if not isinstance(path_or_url, (str, Path)):
        # any sort of BytesIO or similar
        return torch.load(
```
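The switch from `bool = True` to `Optional[bool] = None` lets `_load` distinguish "caller chose a value" from "caller said nothing"; only the latter is upgraded. A minimal sketch of that resolution rule in isolation (the helper name is hypothetical):

```python
import logging
from typing import Optional

import torch
from packaging.version import Version

log = logging.getLogger(__name__)

# Equivalent of the `_TORCH_GREATER_EQUAL_2_6` flag, spelled out with `packaging`.
_TORCH_GE_2_6 = Version(torch.__version__.split("+")[0]) >= Version("2.6.0")


def _resolve_weights_only(weights_only: Optional[bool]) -> Optional[bool]:
    # An explicit True/False from the caller always wins; only an unset (None)
    # value is defaulted to True on torch>=2.6.
    if weights_only is None and _TORCH_GE_2_6:
        log.debug("Defaulting to `weights_only=True` for torch>=2.6.")
        return True
    return weights_only


assert _resolve_weights_only(False) is False        # explicit choice preserved
assert _resolve_weights_only(None) in (None, True)  # version-dependent default
```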

src/lightning/fabric/utilities/imports.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -35,6 +35,7 @@
 _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
 _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
+_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
```
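For context (not part of the diff), `compare_version` from `lightning_utilities` evaluates such flags once at import time against the installed package version:

```python
import operator

from lightning_utilities.core.imports import compare_version

# True iff the installed torch satisfies the comparison; evaluated eagerly, so
# downstream modules can branch on a cheap module-level constant.
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")

if _TORCH_GREATER_EQUAL_2_6:
    print("torch>=2.6: the weights_only=True default applies")
```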

src/lightning/pytorch/core/saving.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -61,10 +61,6 @@ def _load_from_checkpoint(
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     map_location = map_location or _default_map_location
 
-    if weights_only is None:
-        log.debug("`weights_only` not specified, defaulting to `True`.")
-        weights_only = True
-
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only)
 
```
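With the unconditional `weights_only = True` fallback removed here, an unset value now flows through `pl_load` to the fabric `_load` helper above, where the torch-version gate decides. A hedged user-facing sketch (`MyLightningModule` and the checkpoint path are hypothetical):

```python
# Left unset: torch>=2.6 resolves to weights_only=True via the shared loader;
# older torch falls back to torch.load's historical behavior.
model = MyLightningModule.load_from_checkpoint("epoch=9-step=1000.ckpt")

# Explicit opt-out for trusted checkpoints that pickle non-primitive objects.
model = MyLightningModule.load_from_checkpoint(
    "epoch=9-step=1000.ckpt", weights_only=False
)
```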

tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py

Lines changed: 1 addition & 6 deletions

```diff
@@ -46,12 +46,7 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str):
     assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
     path_ckpt = path_ckpts[-1]
 
-    # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166)
-    if pl_version == "local":
-        pl_version = pl.__version__
-    weights_only = not Version(pl_version) < Version("1.5.0")
-
-    model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only)
+    model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24)
     trainer = Trainer(default_root_dir=tmp_path)
     dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
     res = trainer.test(model, datamodule=dm)
```

tests/tests_pytorch/models/test_hparams.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -94,7 +94,7 @@ def __init__(self, hparams, *my_args, **my_kwargs):
 # -------------------------
 # STANDARD TESTS
 # -------------------------
-def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only=True):
+def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False):
     """Tests for the existence of an arg 'test_arg=14'."""
     obj = datamodule if issubclass(cls, LightningDataModule) else model
 
@@ -108,20 +108,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
 
     # make sure the raw checkpoint saved the properties
     raw_checkpoint_path = _raw_checkpoint_path(trainer)
-    raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only)
+    raw_checkpoint = torch.load(raw_checkpoint_path)
 
     assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
     assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14
 
     # verify that model loads correctly
-    obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only)
+    obj2 = cls.load_from_checkpoint(raw_checkpoint_path)
     assert obj2.hparams.test_arg == 14
 
     assert isinstance(obj2.hparams, hparam_type)
 
     if try_overwrite:
         # verify that we can overwrite the property
-        obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only)
+        obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
         assert obj3.hparams.test_arg == 78
 
     return raw_checkpoint_path
@@ -176,7 +176,7 @@ def test_omega_conf_hparams(tmp_path, cls):
     assert isinstance(obj.hparams, Container)
 
     # run standard test suite
-    raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False)
+    raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule)
     obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False)
 
     assert isinstance(obj2.hparams, Container)
```
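The final `load_from_checkpoint(..., weights_only=False)` call stays because OmegaConf hparams are pickled as real `Container` objects, which the restricted loader rejects. A hedged sketch of two ways around that (the file name is hypothetical; the exact set of classes to allowlist varies, and the `UnpicklingError` raised under `weights_only=True` names the blocked globals):

```python
import torch
from omegaconf import DictConfig

# Option 1: trusted checkpoint, opt out of the restricted loader entirely.
ckpt = torch.load("omegaconf_hparams.ckpt", weights_only=False)

# Option 2 (torch>=2.4): allowlist the specific classes and keep the safe mode.
torch.serialization.add_safe_globals([DictConfig])
ckpt = torch.load("omegaconf_hparams.ckpt", weights_only=True)
```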
