@@ -177,6 +177,7 @@ def load_from_checkpoint(
         checkpoint_path: Union[_PATH, IO],
         map_location: _MAP_LOCATION_TYPE = None,
         hparams_file: Optional[_PATH] = None,
+        weights_only: Optional[bool] = None,
         **kwargs: Any,
     ) -> Self:
         r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the
@@ -206,6 +207,11 @@ def load_from_checkpoint(
                 If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace`
                 and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat
                 ``hparams`` as :class:`~dict`.
+            weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
+                primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
+                ``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using
+                ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
             \**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved
                 hyperparameter values.
 
@@ -242,6 +248,7 @@ def load_from_checkpoint(
             map_location=map_location,
             hparams_file=hparams_file,
             strict=None,
+            weights_only=weights_only,
             **kwargs,
         )
         return cast(Self, loaded)