Skip to content

Commit 12bd0d6

Browse files
committed
datamodule weights_only args
1 parent 861d7e0 commit 12bd0d6

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

src/lightning/pytorch/core/datamodule.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def load_from_checkpoint(
177177
checkpoint_path: Union[_PATH, IO],
178178
map_location: _MAP_LOCATION_TYPE = None,
179179
hparams_file: Optional[_PATH] = None,
180+
weights_only: Optional[bool] = None,
180181
**kwargs: Any,
181182
) -> Self:
182183
r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the
@@ -206,6 +207,11 @@ def load_from_checkpoint(
206207
If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace`
207208
and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat
208209
``hparams`` as :class:`~dict`.
210+
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
211+
primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
212+
``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using
213+
``weights_only=True``. For more information, please refer to the
214+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
209215
\**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved
210216
hyperparameter values.
211217
@@ -242,6 +248,7 @@ def load_from_checkpoint(
242248
map_location=map_location,
243249
hparams_file=hparams_file,
244250
strict=None,
251+
weights_only=weights_only,
245252
**kwargs,
246253
)
247254
return cast(Self, loaded)

src/lightning/pytorch/core/module.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,11 @@ def load_from_checkpoint(
17241724
strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
17251725
returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is
17261726
set, in which case it defaults to the value of ``LightningModule.strict_loading``.
1727+
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
1728+
primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
1729+
``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using
1730+
``weights_only=True``. For more information, please refer to the
1731+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
17271732
\**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
17281733
hyperparameter values.
17291734

0 commit comments

Comments
 (0)