@@ -526,6 +526,7 @@ def fit(
526
526
val_dataloaders : Optional [EVAL_DATALOADERS ] = None ,
527
527
datamodule : Optional [LightningDataModule ] = None ,
528
528
ckpt_path : Optional [_PATH ] = None ,
529
+ weights_only : bool = False ,
529
530
) -> None :
530
531
r"""Runs the full optimization routine.
531
532
@@ -573,7 +574,14 @@ def fit(
573
574
self .training = True
574
575
self .should_stop = False
575
576
call ._call_and_handle_interrupt (
576
- self , self ._fit_impl , model , train_dataloaders , val_dataloaders , datamodule , ckpt_path
577
+ self ,
578
+ self ._fit_impl ,
579
+ model ,
580
+ train_dataloaders ,
581
+ val_dataloaders ,
582
+ datamodule ,
583
+ ckpt_path ,
584
+ weights_only ,
577
585
)
578
586
579
587
def _fit_impl (
@@ -583,6 +591,7 @@ def _fit_impl(
583
591
val_dataloaders : Optional [EVAL_DATALOADERS ] = None ,
584
592
datamodule : Optional [LightningDataModule ] = None ,
585
593
ckpt_path : Optional [_PATH ] = None ,
594
+ weights_only : bool = False ,
586
595
) -> None :
587
596
log .debug (f"{ self .__class__ .__name__ } : trainer fit stage" )
588
597
@@ -610,7 +619,7 @@ def _fit_impl(
610
619
model_provided = True ,
611
620
model_connected = self .lightning_module is not None ,
612
621
)
613
- self ._run (model , ckpt_path = ckpt_path )
622
+ self ._run (model , ckpt_path = ckpt_path , weights_only = weights_only )
614
623
615
624
assert self .state .stopped
616
625
self .training = False
@@ -621,6 +630,7 @@ def validate(
621
630
model : Optional ["pl.LightningModule" ] = None ,
622
631
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
623
632
ckpt_path : Optional [_PATH ] = None ,
633
+ weights_only : bool = False ,
624
634
verbose : bool = True ,
625
635
datamodule : Optional [LightningDataModule ] = None ,
626
636
) -> _EVALUATE_OUTPUT :
@@ -676,14 +686,15 @@ def validate(
676
686
self .state .status = TrainerStatus .RUNNING
677
687
self .validating = True
678
688
return call ._call_and_handle_interrupt (
679
- self , self ._validate_impl , model , dataloaders , ckpt_path , verbose , datamodule
689
+ self , self ._validate_impl , model , dataloaders , ckpt_path , weights_only , verbose , datamodule
680
690
)
681
691
682
692
def _validate_impl (
683
693
self ,
684
694
model : Optional ["pl.LightningModule" ] = None ,
685
695
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
686
696
ckpt_path : Optional [_PATH ] = None ,
697
+ weights_only : bool = False ,
687
698
verbose : bool = True ,
688
699
datamodule : Optional [LightningDataModule ] = None ,
689
700
) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
@@ -717,7 +728,7 @@ def _validate_impl(
717
728
ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
718
729
self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
719
730
)
720
- results = self ._run (model , ckpt_path = ckpt_path )
731
+ results = self ._run (model , ckpt_path = ckpt_path , weights_only = weights_only )
721
732
# remove the tensors from the validation results
722
733
results = convert_tensors_to_scalars (results )
723
734
@@ -731,6 +742,7 @@ def test(
731
742
model : Optional ["pl.LightningModule" ] = None ,
732
743
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
733
744
ckpt_path : Optional [_PATH ] = None ,
745
+ weights_only : bool = False ,
734
746
verbose : bool = True ,
735
747
datamodule : Optional [LightningDataModule ] = None ,
736
748
) -> _EVALUATE_OUTPUT :
@@ -787,14 +799,15 @@ def test(
787
799
self .state .status = TrainerStatus .RUNNING
788
800
self .testing = True
789
801
return call ._call_and_handle_interrupt (
790
- self , self ._test_impl , model , dataloaders , ckpt_path , verbose , datamodule
802
+ self , self ._test_impl , model , dataloaders , ckpt_path , weights_only , verbose , datamodule
791
803
)
792
804
793
805
def _test_impl (
794
806
self ,
795
807
model : Optional ["pl.LightningModule" ] = None ,
796
808
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
797
809
ckpt_path : Optional [_PATH ] = None ,
810
+ weights_only : bool = False ,
798
811
verbose : bool = True ,
799
812
datamodule : Optional [LightningDataModule ] = None ,
800
813
) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
@@ -828,7 +841,7 @@ def _test_impl(
828
841
ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
829
842
self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
830
843
)
831
- results = self ._run (model , ckpt_path = ckpt_path )
844
+ results = self ._run (model , ckpt_path = ckpt_path , weights_only = weights_only )
832
845
# remove the tensors from the test results
833
846
results = convert_tensors_to_scalars (results )
834
847
@@ -844,6 +857,7 @@ def predict(
844
857
datamodule : Optional [LightningDataModule ] = None ,
845
858
return_predictions : Optional [bool ] = None ,
846
859
ckpt_path : Optional [_PATH ] = None ,
860
+ weights_only : bool = False ,
847
861
) -> Optional [_PREDICT_OUTPUT ]:
848
862
r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to
849
863
perform distributed and batched predictions. Logging is disabled in the predict hooks.
@@ -899,7 +913,7 @@ def predict(
899
913
self .state .status = TrainerStatus .RUNNING
900
914
self .predicting = True
901
915
return call ._call_and_handle_interrupt (
902
- self , self ._predict_impl , model , dataloaders , datamodule , return_predictions , ckpt_path
916
+ self , self ._predict_impl , model , dataloaders , datamodule , return_predictions , ckpt_path , weights_only
903
917
)
904
918
905
919
def _predict_impl (
@@ -909,6 +923,7 @@ def _predict_impl(
909
923
datamodule : Optional [LightningDataModule ] = None ,
910
924
return_predictions : Optional [bool ] = None ,
911
925
ckpt_path : Optional [_PATH ] = None ,
926
+ weights_only : bool = False ,
912
927
) -> Optional [_PREDICT_OUTPUT ]:
913
928
# --------------------
914
929
# SETUP HOOK
@@ -939,15 +954,15 @@ def _predict_impl(
939
954
ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
940
955
self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
941
956
)
942
- results = self ._run (model , ckpt_path = ckpt_path )
957
+ results = self ._run (model , ckpt_path = ckpt_path , weights_only = weights_only )
943
958
944
959
assert self .state .stopped
945
960
self .predicting = False
946
961
947
962
return results
948
963
949
964
def _run (
950
- self , model : "pl.LightningModule" , ckpt_path : Optional [_PATH ] = None
965
+ self , model : "pl.LightningModule" , ckpt_path : Optional [_PATH ] = None , weights_only : bool = False
951
966
) -> Optional [Union [_EVALUATE_OUTPUT , _PREDICT_OUTPUT ]]:
952
967
if self .state .fn == TrainerFn .FITTING :
953
968
min_epochs , max_epochs = _parse_loop_limits (
@@ -992,7 +1007,7 @@ def _run(
992
1007
# check if we should delay restoring checkpoint till later
993
1008
if not self .strategy .restore_checkpoint_after_setup :
994
1009
log .debug (f"{ self .__class__ .__name__ } : restoring module and callbacks from checkpoint path: { ckpt_path } " )
995
- self ._checkpoint_connector ._restore_modules_and_callbacks (ckpt_path )
1010
+ self ._checkpoint_connector ._restore_modules_and_callbacks (ckpt_path , weights_only )
996
1011
997
1012
# reset logger connector
998
1013
self ._logger_connector .reset_results ()
0 commit comments