|
18 | 18 |
|
19 | 19 | import pytest
|
20 | 20 | import torch
|
| 21 | +from packaging.version import Version |
21 | 22 |
|
22 | 23 | import lightning.pytorch as pl
|
23 | 24 | from lightning.pytorch import Callback, Trainer
|
@@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str):
|
45 | 46 | assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
|
46 | 47 | path_ckpt = path_ckpts[-1]
|
47 | 48 |
|
48 |
| - model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) |
| 49 | + # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) |
| 50 | + if pl_version == "local": |
| 51 | + pl_version = pl.__version__ |
| 52 | + weights_only = not Version(pl_version) < Version("1.5.0") |
| 53 | + |
| 54 | + model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) |
49 | 55 | trainer = Trainer(default_root_dir=tmp_path)
|
50 | 56 | dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
|
51 | 57 | res = trainer.test(model, datamodule=dm)
|
@@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str):
|
73 | 79 | assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
|
74 | 80 | path_ckpt = path_ckpts[-1]
|
75 | 81 |
|
| 82 | + # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) |
| 83 | + if pl_version == "local": |
| 84 | + pl_version = pl.__version__ |
| 85 | + weights_only = not Version(pl_version) < Version("1.5.0") |
| 86 | + |
76 | 87 | def load_model():
|
77 | 88 | import torch
|
78 | 89 |
|
79 | 90 | from lightning.pytorch.utilities.migration import pl_legacy_patch
|
80 | 91 |
|
81 | 92 | with pl_legacy_patch():
|
82 |
| - _ = torch.load(path_ckpt, weights_only=False) |
| 93 | + _ = torch.load(path_ckpt, weights_only=weights_only) |
83 | 94 |
|
84 | 95 | with patch("sys.path", [PATH_LEGACY] + sys.path):
|
85 | 96 | t1 = ThreadExceptionHandler(target=load_model)
|
|
0 commit comments