Skip to content

Commit 436552d

Browse files
committed
weights_only as last arg
1 parent 75ad865 commit 436552d

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def resume_end(self) -> None:
230230
# wait for all to catch up
231231
self.trainer.strategy.barrier("_CheckpointConnector.resume_end")
232232

233-
def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
233+
def restore(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None:
234234
"""Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
235235
state-restore, in this priority:
236236
@@ -244,7 +244,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
244244
checkpoint_path: Path to a PyTorch Lightning checkpoint file.
245245
246246
"""
247-
self.resume_start(checkpoint_path)
247+
self.resume_start(checkpoint_path, weights_only=weights_only)
248248

249249
# restore module states
250250
self.restore_datamodule()

src/lightning/pytorch/trainer/trainer.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,12 @@ def fit(
557557
- ``'registry:version:v2'``: uses the default model set
558558
with ``Trainer(..., model_registry="my-model")`` and version 'v2'
559559
560+
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
561+
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
562+
an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
563+
recommend using ``weights_only=True``. For more information, please refer to the
564+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
565+
560566
Raises:
561567
TypeError:
562568
If ``model`` is not :class:`~lightning.pytorch.core.LightningModule` for torch version less than
@@ -630,9 +636,9 @@ def validate(
630636
model: Optional["pl.LightningModule"] = None,
631637
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
632638
ckpt_path: Optional[_PATH] = None,
633-
weights_only: Optional[bool] = None,
634639
verbose: bool = True,
635640
datamodule: Optional[LightningDataModule] = None,
641+
weights_only: Optional[bool] = None,
636642
) -> _EVALUATE_OUTPUT:
637643
r"""Perform one evaluation epoch over the validation set.
638644
@@ -653,6 +659,12 @@ def validate(
653659
datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
654660
the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
655661
662+
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
663+
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
664+
an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
665+
recommend using ``weights_only=True``. For more information, please refer to the
666+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
667+
656668
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
657669
658670
Returns:
@@ -686,17 +698,17 @@ def validate(
686698
self.state.status = TrainerStatus.RUNNING
687699
self.validating = True
688700
return call._call_and_handle_interrupt(
689-
self, self._validate_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule
701+
self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule, weights_only
690702
)
691703

692704
def _validate_impl(
693705
self,
694706
model: Optional["pl.LightningModule"] = None,
695707
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
696708
ckpt_path: Optional[_PATH] = None,
697-
weights_only: Optional[bool] = None,
698709
verbose: bool = True,
699710
datamodule: Optional[LightningDataModule] = None,
711+
weights_only: Optional[bool] = None,
700712
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
701713
# --------------------
702714
# SETUP HOOK
@@ -742,9 +754,9 @@ def test(
742754
model: Optional["pl.LightningModule"] = None,
743755
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
744756
ckpt_path: Optional[_PATH] = None,
745-
weights_only: Optional[bool] = None,
746757
verbose: bool = True,
747758
datamodule: Optional[LightningDataModule] = None,
759+
weights_only: Optional[bool] = None,
748760
) -> _EVALUATE_OUTPUT:
749761
r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your
750762
test set until you want to.
@@ -766,6 +778,12 @@ def test(
766778
datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
767779
the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
768780
781+
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
782+
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
783+
an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
784+
recommend using ``weights_only=True``. For more information, please refer to the
785+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
786+
769787
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
770788
771789
Returns:
@@ -799,17 +817,17 @@ def test(
799817
self.state.status = TrainerStatus.RUNNING
800818
self.testing = True
801819
return call._call_and_handle_interrupt(
802-
self, self._test_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule
820+
self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule, weights_only
803821
)
804822

805823
def _test_impl(
806824
self,
807825
model: Optional["pl.LightningModule"] = None,
808826
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
809827
ckpt_path: Optional[_PATH] = None,
810-
weights_only: Optional[bool] = None,
811828
verbose: bool = True,
812829
datamodule: Optional[LightningDataModule] = None,
830+
weights_only: Optional[bool] = None,
813831
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
814832
# --------------------
815833
# SETUP HOOK
@@ -880,6 +898,12 @@ def predict(
880898
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
881899
if a checkpoint callback is configured.
882900
901+
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
902+
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
903+
an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
904+
recommend using ``weights_only=True``. For more information, please refer to the
905+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
906+
883907
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
884908
885909
Returns:
@@ -920,7 +944,7 @@ def predict(
920944
datamodule,
921945
return_predictions,
922946
ckpt_path,
923-
weights_only=weights_only,
947+
weights_only,
924948
)
925949

926950
def _predict_impl(

0 commit comments

Comments
 (0)