Skip to content

Commit 74e5e5a

Browse files
committed
weights_only according pl version
1 parent 2ab89a2 commit 74e5e5a

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str):
4646
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
4747
path_ckpt = path_ckpts[-1]
4848

49-
model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24)
49+
if pl_version == "local":
50+
pl_version = pl.__version__
51+
52+
weights_only = Version(pl_version) >= Version("1.5.0")
53+
54+
model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only)
5055
trainer = Trainer(default_root_dir=tmp_path)
5156
dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
5257
res = trainer.test(model, datamodule=dm)

0 commit comments

Comments
 (0)