Skip to content

Commit 46cc788

Browse files
committed
trainer defaults to weights_only=None
1 parent 653dd6f commit 46cc788

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def save_checkpoint(
329329
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
330330

331331
@override
332-
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool = False) -> dict[str, Any]:
332+
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
333333
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
334334
path = Path(self.broadcast(checkpoint_path))
335335
state = {

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _hpc_resume_path(self) -> Optional[str]:
6464
return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
6565
return None
6666

67-
def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False) -> None:
67+
def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None:
6868
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
6969
7070
1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`.
@@ -404,7 +404,7 @@ def restore_lr_schedulers(self) -> None:
404404
config.scheduler.load_state_dict(lrs_state)
405405

406406
def _restore_modules_and_callbacks(
407-
self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False
407+
self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None
408408
) -> None:
409409
# restore modules after setup
410410
self.resume_start(checkpoint_path, weights_only)

src/lightning/pytorch/trainer/trainer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def fit(
526526
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
527527
datamodule: Optional[LightningDataModule] = None,
528528
ckpt_path: Optional[_PATH] = None,
529-
weights_only: bool = False,
529+
weights_only: Optional[bool] = None,
530530
) -> None:
531531
r"""Runs the full optimization routine.
532532
@@ -591,7 +591,7 @@ def _fit_impl(
591591
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
592592
datamodule: Optional[LightningDataModule] = None,
593593
ckpt_path: Optional[_PATH] = None,
594-
weights_only: bool = False,
594+
weights_only: Optional[bool] = None,
595595
) -> None:
596596
log.debug(f"{self.__class__.__name__}: trainer fit stage")
597597

@@ -630,7 +630,7 @@ def validate(
630630
model: Optional["pl.LightningModule"] = None,
631631
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
632632
ckpt_path: Optional[_PATH] = None,
633-
weights_only: bool = False,
633+
weights_only: Optional[bool] = None,
634634
verbose: bool = True,
635635
datamodule: Optional[LightningDataModule] = None,
636636
) -> _EVALUATE_OUTPUT:
@@ -694,7 +694,7 @@ def _validate_impl(
694694
model: Optional["pl.LightningModule"] = None,
695695
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
696696
ckpt_path: Optional[_PATH] = None,
697-
weights_only: bool = False,
697+
weights_only: Optional[bool] = None,
698698
verbose: bool = True,
699699
datamodule: Optional[LightningDataModule] = None,
700700
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@@ -742,7 +742,7 @@ def test(
742742
model: Optional["pl.LightningModule"] = None,
743743
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
744744
ckpt_path: Optional[_PATH] = None,
745-
weights_only: bool = False,
745+
weights_only: Optional[bool] = None,
746746
verbose: bool = True,
747747
datamodule: Optional[LightningDataModule] = None,
748748
) -> _EVALUATE_OUTPUT:
@@ -807,7 +807,7 @@ def _test_impl(
807807
model: Optional["pl.LightningModule"] = None,
808808
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
809809
ckpt_path: Optional[_PATH] = None,
810-
weights_only: bool = False,
810+
weights_only: Optional[bool] = None,
811811
verbose: bool = True,
812812
datamodule: Optional[LightningDataModule] = None,
813813
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@@ -857,7 +857,7 @@ def predict(
857857
datamodule: Optional[LightningDataModule] = None,
858858
return_predictions: Optional[bool] = None,
859859
ckpt_path: Optional[_PATH] = None,
860-
weights_only: bool = False,
860+
weights_only: Optional[bool] = None,
861861
) -> Optional[_PREDICT_OUTPUT]:
862862
r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to
863863
perform distributed and batched predictions. Logging is disabled in the predict hooks.
@@ -923,7 +923,7 @@ def _predict_impl(
923923
datamodule: Optional[LightningDataModule] = None,
924924
return_predictions: Optional[bool] = None,
925925
ckpt_path: Optional[_PATH] = None,
926-
weights_only: bool = False,
926+
weights_only: Optional[bool] = None,
927927
) -> Optional[_PREDICT_OUTPUT]:
928928
# --------------------
929929
# SETUP HOOK
@@ -962,7 +962,10 @@ def _predict_impl(
962962
return results
963963

964964
def _run(
965-
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None, weights_only: bool = False
965+
self,
966+
model: "pl.LightningModule",
967+
ckpt_path: Optional[_PATH] = None,
968+
weights_only: Optional[bool] = None,
966969
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
967970
if self.state.fn == TrainerFn.FITTING:
968971
min_epochs, max_epochs = _parse_loop_limits(
@@ -1401,7 +1404,7 @@ def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
14011404
self._checkpoint_connector._user_managed = bool(ckpt_path)
14021405

14031406
def save_checkpoint(
1404-
self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
1407+
self, filepath: _PATH, weights_only: Optional[bool] = None, storage_options: Optional[Any] = None
14051408
) -> None:
14061409
r"""Runs routine to create a checkpoint.
14071410

0 commit comments

Comments (0)