Skip to content

Commit a4c9efe

Browse files
committed
set env var TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 for pl < 1.5.0
1 parent 74e5e5a commit a4c9efe

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

src/lightning/fabric/utilities/cloud_io.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from fsspec.implementations.local import AbstractFileSystem
2727
from lightning_utilities.core.imports import module_available
2828

29-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
3029
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
3130

3231
log = logging.getLogger(__name__)
@@ -49,11 +48,6 @@ def _load(
4948
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
5049
5150
"""
52-
# default to `weights_only=True` for torch>=2.6
53-
if weights_only is None and _TORCH_GREATER_EQUAL_2_6:
54-
log.debug("Defaulting to `weights_only=True` for torch>=2.6.")
55-
weights_only = True
56-
5751
if not isinstance(path_or_url, (str, Path)):
5852
# any sort of BytesIO or similar
5953
return torch.load(

tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,14 @@ def load_model():
105105

106106
@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
107107
@RunIf(sklearn=True)
108-
def test_resume_legacy_checkpoints(tmp_path, pl_version: str):
108+
def test_resume_legacy_checkpoints(monkeypatch, tmp_path, pl_version: str):
109109
PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
110110
with patch("sys.path", [PATH_LEGACY] + sys.path):
111+
if pl_version == "local":
112+
pl_version = pl.__version__
113+
if Version(pl_version) < Version("1.5.0"):
114+
monkeypatch.setenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
115+
111116
path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
112117
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
113118
path_ckpt = path_ckpts[-1]

0 commit comments

Comments
 (0)