@@ -526,7 +526,7 @@ def fit(
526
526
val_dataloaders : Optional [EVAL_DATALOADERS ] = None ,
527
527
datamodule : Optional [LightningDataModule ] = None ,
528
528
ckpt_path : Optional [_PATH ] = None ,
529
- weights_only : bool = False ,
529
+ weights_only : Optional [ bool ] = None ,
530
530
) -> None :
531
531
r"""Runs the full optimization routine.
532
532
@@ -591,7 +591,7 @@ def _fit_impl(
591
591
val_dataloaders : Optional [EVAL_DATALOADERS ] = None ,
592
592
datamodule : Optional [LightningDataModule ] = None ,
593
593
ckpt_path : Optional [_PATH ] = None ,
594
- weights_only : bool = False ,
594
+ weights_only : Optional [ bool ] = None ,
595
595
) -> None :
596
596
log .debug (f"{ self .__class__ .__name__ } : trainer fit stage" )
597
597
@@ -630,7 +630,7 @@ def validate(
630
630
model : Optional ["pl.LightningModule" ] = None ,
631
631
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
632
632
ckpt_path : Optional [_PATH ] = None ,
633
- weights_only : bool = False ,
633
+ weights_only : Optional [ bool ] = None ,
634
634
verbose : bool = True ,
635
635
datamodule : Optional [LightningDataModule ] = None ,
636
636
) -> _EVALUATE_OUTPUT :
@@ -694,7 +694,7 @@ def _validate_impl(
694
694
model : Optional ["pl.LightningModule" ] = None ,
695
695
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
696
696
ckpt_path : Optional [_PATH ] = None ,
697
- weights_only : bool = False ,
697
+ weights_only : Optional [ bool ] = None ,
698
698
verbose : bool = True ,
699
699
datamodule : Optional [LightningDataModule ] = None ,
700
700
) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
@@ -742,7 +742,7 @@ def test(
742
742
model : Optional ["pl.LightningModule" ] = None ,
743
743
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
744
744
ckpt_path : Optional [_PATH ] = None ,
745
- weights_only : bool = False ,
745
+ weights_only : Optional [ bool ] = None ,
746
746
verbose : bool = True ,
747
747
datamodule : Optional [LightningDataModule ] = None ,
748
748
) -> _EVALUATE_OUTPUT :
@@ -807,7 +807,7 @@ def _test_impl(
807
807
model : Optional ["pl.LightningModule" ] = None ,
808
808
dataloaders : Optional [Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
809
809
ckpt_path : Optional [_PATH ] = None ,
810
- weights_only : bool = False ,
810
+ weights_only : Optional [ bool ] = None ,
811
811
verbose : bool = True ,
812
812
datamodule : Optional [LightningDataModule ] = None ,
813
813
) -> Optional [Union [_PREDICT_OUTPUT , _EVALUATE_OUTPUT ]]:
@@ -857,7 +857,7 @@ def predict(
857
857
datamodule : Optional [LightningDataModule ] = None ,
858
858
return_predictions : Optional [bool ] = None ,
859
859
ckpt_path : Optional [_PATH ] = None ,
860
- weights_only : bool = False ,
860
+ weights_only : Optional [ bool ] = None ,
861
861
) -> Optional [_PREDICT_OUTPUT ]:
862
862
r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to
863
863
perform distributed and batched predictions. Logging is disabled in the predict hooks.
@@ -923,7 +923,7 @@ def _predict_impl(
923
923
datamodule : Optional [LightningDataModule ] = None ,
924
924
return_predictions : Optional [bool ] = None ,
925
925
ckpt_path : Optional [_PATH ] = None ,
926
- weights_only : bool = False ,
926
+ weights_only : Optional [ bool ] = None ,
927
927
) -> Optional [_PREDICT_OUTPUT ]:
928
928
# --------------------
929
929
# SETUP HOOK
@@ -962,7 +962,10 @@ def _predict_impl(
962
962
return results
963
963
964
964
def _run (
965
- self , model : "pl.LightningModule" , ckpt_path : Optional [_PATH ] = None , weights_only : bool = False
965
+ self ,
966
+ model : "pl.LightningModule" ,
967
+ ckpt_path : Optional [_PATH ] = None ,
968
+ weights_only : Optional [bool ] = None ,
966
969
) -> Optional [Union [_EVALUATE_OUTPUT , _PREDICT_OUTPUT ]]:
967
970
if self .state .fn == TrainerFn .FITTING :
968
971
min_epochs , max_epochs = _parse_loop_limits (
@@ -1401,7 +1404,7 @@ def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
1401
1404
self ._checkpoint_connector ._user_managed = bool (ckpt_path )
1402
1405
1403
1406
def save_checkpoint (
1404
- self , filepath : _PATH , weights_only : bool = False , storage_options : Optional [Any ] = None
1407
+ self , filepath : _PATH , weights_only : Optional [ bool ] = None , storage_options : Optional [Any ] = None
1405
1408
) -> None :
1406
1409
r"""Runs routine to create a checkpoint.
1407
1410
0 commit comments