Skip to content

Commit f276114

Browse files
committed
Add a `weights_only` argument to checkpoint saving/loading; during tests, set `weights_only` based on the checkpoint's Lightning version.
1 parent 4eaaf58 commit f276114

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

src/lightning/pytorch/core/saving.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,17 @@ def _load_from_checkpoint(
5656
map_location: _MAP_LOCATION_TYPE = None,
5757
hparams_file: Optional[_PATH] = None,
5858
strict: Optional[bool] = None,
59+
weights_only: Optional[bool] = None,
5960
**kwargs: Any,
6061
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
6162
map_location = map_location or _default_map_location
63+
64+
if weights_only is None:
65+
log.debug("`weights_only` not specified, defaulting to `True`.")
66+
weights_only = True
67+
6268
with pl_legacy_patch():
63-
checkpoint = pl_load(checkpoint_path, map_location=map_location)
69+
checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only)
6470

6571
# convert legacy checkpoints to the new format
6672
checkpoint = _pl_migrate_checkpoint(

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
414414
"""Creating a model checkpoint dictionary object from various component states.
415415
416416
Args:
417-
weights_only: saving model weights only
417+
weights_only: If True, only saves model and loops state_dict objects. If False,
418+
additionally saves callbacks, optimizers, schedulers, and precision plugin states.
419+
418420
Return:
419421
structured dictionary: {
420422
'epoch': training epoch

tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import pytest
2020
import torch
21+
from packaging.version import Version
2122

2223
import lightning.pytorch as pl
2324
from lightning.pytorch import Callback, Trainer
@@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str):
4546
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
4647
path_ckpt = path_ckpts[-1]
4748

48-
model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24)
49+
# legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166)
50+
if pl_version == "local":
51+
pl_version = pl.__version__
52+
weights_only = not Version(pl_version) < Version("1.5.0")
53+
54+
model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only)
4955
trainer = Trainer(default_root_dir=tmp_path)
5056
dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
5157
res = trainer.test(model, datamodule=dm)
@@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str):
7379
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
7480
path_ckpt = path_ckpts[-1]
7581

82+
# legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166)
83+
if pl_version == "local":
84+
pl_version = pl.__version__
85+
weights_only = not Version(pl_version) < Version("1.5.0")
86+
7687
def load_model():
7788
import torch
7889

7990
from lightning.pytorch.utilities.migration import pl_legacy_patch
8091

8192
with pl_legacy_patch():
82-
_ = torch.load(path_ckpt, weights_only=False)
93+
_ = torch.load(path_ckpt, weights_only=weights_only)
8394

8495
with patch("sys.path", [PATH_LEGACY] + sys.path):
8596
t1 = ThreadExceptionHandler(target=load_model)

0 commit comments

Comments
 (0)