@@ -557,6 +557,12 @@ def fit(
557557 - ``'registry:version:v2'``: uses the default model set
558558 with ``Trainer(..., model_registry="my-model")`` and version 'v2'
559559
560+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
561+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
562+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
563+ recommend using ``weights_only=True``. For more information, please refer to the
564+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
565+
560566 Raises:
561567 TypeError:
562568 If ``model`` is not :class:`~lightning.pytorch.core.LightningModule` for torch version less than
@@ -630,9 +636,9 @@ def validate(
630636 model : Optional ["pl.LightningModule" ] = None ,
631637 dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
632638 ckpt_path : Optional [_PATH ] = None ,
633- weights_only : Optional [bool ] = None ,
634639 verbose : bool = True ,
635640 datamodule : Optional [LightningDataModule ] = None ,
641+ weights_only : Optional [bool ] = None ,
636642 ) -> _EVALUATE_OUTPUT :
637643 r"""Perform one evaluation epoch over the validation set.
638644
@@ -653,6 +659,12 @@ def validate(
653659 datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
654660 the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
655661
662+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
663+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
664+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
665+ recommend using ``weights_only=True``. For more information, please refer to the
666+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
667+
656668 For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
657669
658670 Returns:
@@ -686,17 +698,17 @@ def validate(
686698 self .state .status = TrainerStatus .RUNNING
687699 self .validating = True
688700 return call ._call_and_handle_interrupt (
689- self , self ._validate_impl , model , dataloaders , ckpt_path , weights_only , verbose , datamodule
701+ self , self ._validate_impl , model , dataloaders , ckpt_path , verbose , datamodule , weights_only
690702 )
691703
692704 def _validate_impl (
693705 self ,
694706 model : Optional ["pl.LightningModule" ] = None ,
695707 dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
696708 ckpt_path : Optional [_PATH ] = None ,
697- weights_only : Optional [bool ] = None ,
698709 verbose : bool = True ,
699710 datamodule : Optional [LightningDataModule ] = None ,
711+ weights_only : Optional [bool ] = None ,
700712 ) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
701713 # --------------------
702714 # SETUP HOOK
@@ -742,9 +754,9 @@ def test(
742754 model : Optional ["pl.LightningModule" ] = None ,
743755 dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
744756 ckpt_path : Optional [_PATH ] = None ,
745- weights_only : Optional [bool ] = None ,
746757 verbose : bool = True ,
747758 datamodule : Optional [LightningDataModule ] = None ,
759+ weights_only : Optional [bool ] = None ,
748760 ) -> _EVALUATE_OUTPUT :
749761 r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your
750762 test set until you want to.
@@ -766,6 +778,12 @@ def test(
766778 datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
767779 the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
768780
781+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
782+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
783+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
784+ recommend using ``weights_only=True``. For more information, please refer to the
785+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
786+
769787 For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
770788
771789 Returns:
@@ -799,17 +817,17 @@ def test(
799817 self .state .status = TrainerStatus .RUNNING
800818 self .testing = True
801819 return call ._call_and_handle_interrupt (
802- self , self ._test_impl , model , dataloaders , ckpt_path , weights_only , verbose , datamodule
820+ self , self ._test_impl , model , dataloaders , ckpt_path , verbose , datamodule , weights_only
803821 )
804822
805823 def _test_impl (
806824 self ,
807825 model : Optional ["pl.LightningModule" ] = None ,
808826 dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
809827 ckpt_path : Optional [_PATH ] = None ,
810- weights_only : Optional [bool ] = None ,
811828 verbose : bool = True ,
812829 datamodule : Optional [LightningDataModule ] = None ,
830+ weights_only : Optional [bool ] = None ,
813831 ) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
814832 # --------------------
815833 # SETUP HOOK
@@ -880,6 +898,12 @@ def predict(
880898 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
881899 if a checkpoint callback is configured.
882900
901+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
902+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
903+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
904+ recommend using ``weights_only=True``. For more information, please refer to the
905+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
906+
883907 For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
884908
885909 Returns:
@@ -920,7 +944,7 @@ def predict(
920944 datamodule ,
921945 return_predictions ,
922946 ckpt_path ,
923- weights_only = weights_only ,
947+ weights_only ,
924948 )
925949
926950 def _predict_impl (
0 commit comments