Skip to content

Commit 28f53ae

Browse files
committed
add weights_only arg to checkpoint_io
1 parent 601e300 commit 28f53ae

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

src/lightning/fabric/plugins/io/checkpoint_io.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
4747
"""
4848

4949
@abstractmethod
50-
def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]:
50+
def load_checkpoint(
51+
self, path: _PATH, map_location: Optional[Any] = None, weights_only: bool = True
52+
) -> dict[str, Any]:
5153
"""Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
5254
5355
Args:
5456
path: Path to checkpoint
5557
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
5658
locations.
59+
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
60+
primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
61+
``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
62+
``weights_only=True``. For more information, please refer to the
63+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`__.
5764
5865
Returns: The loaded checkpoint.
5966

src/lightning/fabric/plugins/io/torch_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
5959

6060
@override
6161
def load_checkpoint(
62-
self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
62+
self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage, weights_only: bool = True
6363
) -> dict[str, Any]:
6464
"""Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.
6565
@@ -80,7 +80,7 @@ def load_checkpoint(
8080
if not fs.exists(path):
8181
raise FileNotFoundError(f"Checkpoint file not found: {path}")
8282

83-
return pl_load(path, map_location=map_location)
83+
return pl_load(path, map_location=map_location, weights_only=weights_only)
8484

8585
@override
8686
def remove_checkpoint(self, path: _PATH) -> None:

tests/legacy/generate_checkpoints.sh

100644100755
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ printf "PYTHONPATH: $PYTHONPATH"
1616
rm -rf $ENV_PATH
1717

1818
function create_and_save_checkpoint {
19-
python --version
20-
python -m pip --version
21-
python -m pip list
19+
# python --version
20+
# python -m pip --version
21+
# python -m pip list
2222

2323
python $LEGACY_FOLDER/simple_classif_training.py $pl_ver
2424

@@ -52,10 +52,12 @@ done
5252
if [[ -z "$@" ]]; then
5353
printf "\n\n processing local version\n"
5454

55-
python -m pip install \
55+
# python -m pip install \
56+
uv pip install \
5657
-r $LEGACY_FOLDER/requirements.txt \
5758
-r "$(dirname $TESTS_FOLDER)/requirements/pytorch/test.txt" \
5859
-f https://download.pytorch.org/whl/cpu/torch_stable.html
5960
pl_ver="local"
61+
# pl_ver=$(python -c "import lightning.pytorch as pl; print(pl.__version__)")
6062
create_and_save_checkpoint
6163
fi

tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ class CustomCheckpointIO(CheckpointIO):
3232
def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
3333
torch.save(checkpoint, path)
3434

35-
def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]:
35+
def load_checkpoint(
36+
self, path: _PATH, storage_options: Optional[Any] = None, weights_only: bool = True
37+
) -> dict[str, Any]:
3638
return torch.load(path, weights_only=True)
3739

3840
def remove_checkpoint(self, path: _PATH) -> None:

0 commit comments

Comments
 (0)