@@ -557,6 +557,12 @@ def fit(
557
557
- ``'registry:version:v2'``: uses the default model set
558
558
with ``Trainer(..., model_registry="my-model")`` and version 'v2'
559
559
560
+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
561
+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
562
+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
563
+ recommend using ``weights_only=True``. For more information, please refer to the
564
+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
565
+
560
566
Raises:
561
567
TypeError:
562
568
If ``model`` is not :class:`~lightning.pytorch.core.LightningModule` for torch version less than
@@ -630,9 +636,9 @@ def validate(
630
636
model : Optional ["pl.LightningModule" ] = None ,
631
637
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
632
638
ckpt_path : Optional [_PATH ] = None ,
633
- weights_only : Optional [bool ] = None ,
634
639
verbose : bool = True ,
635
640
datamodule : Optional [LightningDataModule ] = None ,
641
+ weights_only : Optional [bool ] = None ,
636
642
) -> _EVALUATE_OUTPUT :
637
643
r"""Perform one evaluation epoch over the validation set.
638
644
@@ -653,6 +659,12 @@ def validate(
653
659
datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
654
660
the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
655
661
662
+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
663
+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
664
+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
665
+ recommend using ``weights_only=True``. For more information, please refer to the
666
+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
667
+
656
668
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
657
669
658
670
Returns:
@@ -686,17 +698,17 @@ def validate(
686
698
self .state .status = TrainerStatus .RUNNING
687
699
self .validating = True
688
700
return call ._call_and_handle_interrupt (
689
- self , self ._validate_impl , model , dataloaders , ckpt_path , weights_only , verbose , datamodule
701
+ self , self ._validate_impl , model , dataloaders , ckpt_path , verbose , datamodule , weights_only
690
702
)
691
703
692
704
def _validate_impl (
693
705
self ,
694
706
model : Optional ["pl.LightningModule" ] = None ,
695
707
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
696
708
ckpt_path : Optional [_PATH ] = None ,
697
- weights_only : Optional [bool ] = None ,
698
709
verbose : bool = True ,
699
710
datamodule : Optional [LightningDataModule ] = None ,
711
+ weights_only : Optional [bool ] = None ,
700
712
) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
701
713
# --------------------
702
714
# SETUP HOOK
@@ -742,9 +754,9 @@ def test(
742
754
model : Optional ["pl.LightningModule" ] = None ,
743
755
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
744
756
ckpt_path : Optional [_PATH ] = None ,
745
- weights_only : Optional [bool ] = None ,
746
757
verbose : bool = True ,
747
758
datamodule : Optional [LightningDataModule ] = None ,
759
+ weights_only : Optional [bool ] = None ,
748
760
) -> _EVALUATE_OUTPUT :
749
761
r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your
750
762
test set until you want to.
@@ -766,6 +778,12 @@ def test(
766
778
datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
767
779
the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
768
780
781
+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
782
+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
783
+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
784
+ recommend using ``weights_only=True``. For more information, please refer to the
785
+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
786
+
769
787
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
770
788
771
789
Returns:
@@ -799,17 +817,17 @@ def test(
799
817
self .state .status = TrainerStatus .RUNNING
800
818
self .testing = True
801
819
return call ._call_and_handle_interrupt (
802
- self , self ._test_impl , model , dataloaders , ckpt_path , weights_only , verbose , datamodule
820
+ self , self ._test_impl , model , dataloaders , ckpt_path , verbose , datamodule , weights_only
803
821
)
804
822
805
823
def _test_impl (
806
824
self ,
807
825
model : Optional ["pl.LightningModule" ] = None ,
808
826
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
809
827
ckpt_path : Optional [_PATH ] = None ,
810
- weights_only : Optional [bool ] = None ,
811
828
verbose : bool = True ,
812
829
datamodule : Optional [LightningDataModule ] = None ,
830
+ weights_only : Optional [bool ] = None ,
813
831
) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
814
832
# --------------------
815
833
# SETUP HOOK
@@ -880,6 +898,12 @@ def predict(
880
898
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
881
899
if a checkpoint callback is configured.
882
900
901
+ weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
902
+ ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
903
+ an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
904
+ recommend using ``weights_only=True``. For more information, please refer to the
905
+ `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
906
+
883
907
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
884
908
885
909
Returns:
@@ -920,7 +944,7 @@ def predict(
920
944
datamodule ,
921
945
return_predictions ,
922
946
ckpt_path ,
923
- weights_only = weights_only ,
947
+ weights_only ,
924
948
)
925
949
926
950
def _predict_impl (
0 commit comments