Skip to content

Commit 5ac9695

Browse files
committed
add weights_only to trainer.fit, validate, test, predict
1 parent: a3183ba — commit: 5ac9695

File tree

3 files changed

+33
-16
lines changed

3 files changed

+33
-16
lines changed

src/lightning/pytorch/strategies/strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,9 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
363363
"""Returns the pure LightningModule without potential wrappers."""
364364
return self._lightning_module
365365

366-
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
366+
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool) -> dict[str, Any]:
367367
torch.cuda.empty_cache()
368-
return self.checkpoint_io.load_checkpoint(checkpoint_path)
368+
return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only)
369369

370370
def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
371371
assert self.lightning_module is not None

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _hpc_resume_path(self) -> Optional[str]:
6464
return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
6565
return None
6666

67-
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
67+
def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False) -> None:
6868
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
6969
7070
1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`.
@@ -80,7 +80,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
8080

8181
rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
8282
with pl_legacy_patch():
83-
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
83+
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only)
8484
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
8585

8686
def _select_ckpt_path(
@@ -403,9 +403,11 @@ def restore_lr_schedulers(self) -> None:
403403
for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers):
404404
config.scheduler.load_state_dict(lrs_state)
405405

406-
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
406+
def _restore_modules_and_callbacks(
407+
self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False
408+
) -> None:
407409
# restore modules after setup
408-
self.resume_start(checkpoint_path)
410+
self.resume_start(checkpoint_path, weights_only)
409411
self.restore_model()
410412
self.restore_datamodule()
411413
self.restore_callbacks()

src/lightning/pytorch/trainer/trainer.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ def fit(
526526
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
527527
datamodule: Optional[LightningDataModule] = None,
528528
ckpt_path: Optional[_PATH] = None,
529+
weights_only: bool = False,
529530
) -> None:
530531
r"""Runs the full optimization routine.
531532
@@ -573,7 +574,14 @@ def fit(
573574
self.training = True
574575
self.should_stop = False
575576
call._call_and_handle_interrupt(
576-
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
577+
self,
578+
self._fit_impl,
579+
model,
580+
train_dataloaders,
581+
val_dataloaders,
582+
datamodule,
583+
ckpt_path,
584+
weights_only,
577585
)
578586

579587
def _fit_impl(
@@ -583,6 +591,7 @@ def _fit_impl(
583591
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
584592
datamodule: Optional[LightningDataModule] = None,
585593
ckpt_path: Optional[_PATH] = None,
594+
weights_only: bool = False,
586595
) -> None:
587596
log.debug(f"{self.__class__.__name__}: trainer fit stage")
588597

@@ -610,7 +619,7 @@ def _fit_impl(
610619
model_provided=True,
611620
model_connected=self.lightning_module is not None,
612621
)
613-
self._run(model, ckpt_path=ckpt_path)
622+
self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
614623

615624
assert self.state.stopped
616625
self.training = False
@@ -621,6 +630,7 @@ def validate(
621630
model: Optional["pl.LightningModule"] = None,
622631
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
623632
ckpt_path: Optional[_PATH] = None,
633+
weights_only: bool = False,
624634
verbose: bool = True,
625635
datamodule: Optional[LightningDataModule] = None,
626636
) -> _EVALUATE_OUTPUT:
@@ -676,14 +686,15 @@ def validate(
676686
self.state.status = TrainerStatus.RUNNING
677687
self.validating = True
678688
return call._call_and_handle_interrupt(
679-
self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
689+
self, self._validate_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule
680690
)
681691

682692
def _validate_impl(
683693
self,
684694
model: Optional["pl.LightningModule"] = None,
685695
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
686696
ckpt_path: Optional[_PATH] = None,
697+
weights_only: bool = False,
687698
verbose: bool = True,
688699
datamodule: Optional[LightningDataModule] = None,
689700
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@@ -717,7 +728,7 @@ def _validate_impl(
717728
ckpt_path = self._checkpoint_connector._select_ckpt_path(
718729
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
719730
)
720-
results = self._run(model, ckpt_path=ckpt_path)
731+
results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
721732
# remove the tensors from the validation results
722733
results = convert_tensors_to_scalars(results)
723734

@@ -731,6 +742,7 @@ def test(
731742
model: Optional["pl.LightningModule"] = None,
732743
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
733744
ckpt_path: Optional[_PATH] = None,
745+
weights_only: bool = False,
734746
verbose: bool = True,
735747
datamodule: Optional[LightningDataModule] = None,
736748
) -> _EVALUATE_OUTPUT:
@@ -787,14 +799,15 @@ def test(
787799
self.state.status = TrainerStatus.RUNNING
788800
self.testing = True
789801
return call._call_and_handle_interrupt(
790-
self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
802+
self, self._test_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule
791803
)
792804

793805
def _test_impl(
794806
self,
795807
model: Optional["pl.LightningModule"] = None,
796808
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
797809
ckpt_path: Optional[_PATH] = None,
810+
weights_only: bool = False,
798811
verbose: bool = True,
799812
datamodule: Optional[LightningDataModule] = None,
800813
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@@ -828,7 +841,7 @@ def _test_impl(
828841
ckpt_path = self._checkpoint_connector._select_ckpt_path(
829842
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
830843
)
831-
results = self._run(model, ckpt_path=ckpt_path)
844+
results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
832845
# remove the tensors from the test results
833846
results = convert_tensors_to_scalars(results)
834847

@@ -844,6 +857,7 @@ def predict(
844857
datamodule: Optional[LightningDataModule] = None,
845858
return_predictions: Optional[bool] = None,
846859
ckpt_path: Optional[_PATH] = None,
860+
weights_only: bool = False,
847861
) -> Optional[_PREDICT_OUTPUT]:
848862
r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to
849863
perform distributed and batched predictions. Logging is disabled in the predict hooks.
@@ -899,7 +913,7 @@ def predict(
899913
self.state.status = TrainerStatus.RUNNING
900914
self.predicting = True
901915
return call._call_and_handle_interrupt(
902-
self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
916+
self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path, weights_only
903917
)
904918

905919
def _predict_impl(
@@ -909,6 +923,7 @@ def _predict_impl(
909923
datamodule: Optional[LightningDataModule] = None,
910924
return_predictions: Optional[bool] = None,
911925
ckpt_path: Optional[_PATH] = None,
926+
weights_only: bool = False,
912927
) -> Optional[_PREDICT_OUTPUT]:
913928
# --------------------
914929
# SETUP HOOK
@@ -939,15 +954,15 @@ def _predict_impl(
939954
ckpt_path = self._checkpoint_connector._select_ckpt_path(
940955
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
941956
)
942-
results = self._run(model, ckpt_path=ckpt_path)
957+
results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
943958

944959
assert self.state.stopped
945960
self.predicting = False
946961

947962
return results
948963

949964
def _run(
950-
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
965+
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None, weights_only: bool = False
951966
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
952967
if self.state.fn == TrainerFn.FITTING:
953968
min_epochs, max_epochs = _parse_loop_limits(
@@ -992,7 +1007,7 @@ def _run(
9921007
# check if we should delay restoring checkpoint till later
9931008
if not self.strategy.restore_checkpoint_after_setup:
9941009
log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
995-
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
1010+
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path, weights_only)
9961011

9971012
# reset logger connector
9981013
self._logger_connector.reset_results()

0 commit comments

Comments (0)