
Commit 6f93a90

Authored by vsey with pre-commit-ci[bot] and Borda
Add save_on_exception option to ModelCheckpoint (#20916)
* add saving of checkpoint if an exception is raised
* import callback to checkpoint test file
* add test for exception in training callbacks

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
1 parent 577c04d commit 6f93a90
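From a user's perspective, `save_on_exception` is an opt-in flag on the callback. A minimal usage sketch (the directory, filename template, and trainer settings below are illustrative, not from this commit):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# opt in to saving a checkpoint when an exception interrupts training
checkpoint = ModelCheckpoint(
    dirpath="checkpoints/",     # illustrative path
    filename="{epoch}-{step}",
    save_on_exception=True,     # new in this commit; defaults to False
)
trainer = Trainer(callbacks=[checkpoint], max_epochs=10)
```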

File tree

3 files changed (+469, -6 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 1 deletion

@@ -10,8 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added support for general mappings being returned from `training_step` when using manual optimization ([#21011](https://github.com/Lightning-AI/pytorch-lightning/pull/21011))
+- Added `save_on_exception` option to `ModelCheckpoint` Callback ([#20916](https://github.com/Lightning-AI/pytorch-lightning/pull/20916))
+
 
+- Added support for general mappings being returned from `training_step` when using manual optimization ([#21011](https://github.com/Lightning-AI/pytorch-lightning/pull/21011))
 
 
 ### Changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 34 additions & 4 deletions
@@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint):
             collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
             ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
             collisions.
+        save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
         mode: one of {min, max}.
             If ``save_top_k != 0``, the decision to overwrite the current save file is made
             based on either the maximization or the minimization of the monitored quantity.
@@ -230,6 +231,7 @@ def __init__(
         verbose: bool = False,
         save_last: Optional[Union[bool, Literal["link"]]] = None,
         save_top_k: int = 1,
+        save_on_exception: bool = False,
         save_weights_only: bool = False,
         mode: str = "min",
         auto_insert_metric_name: bool = True,
@@ -244,6 +246,7 @@ def __init__(
         self.verbose = verbose
         self.save_last = save_last
         self.save_top_k = save_top_k
+        self.save_on_exception = save_on_exception
         self.save_weights_only = save_weights_only
         self.auto_insert_metric_name = auto_insert_metric_name
         self._save_on_train_epoch_end = save_on_train_epoch_end
@@ -345,6 +348,19 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
+    def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
+        """Save a checkpoint when an exception is raised."""
+        if not self._should_save_on_exception(trainer):
+            return
+        monitor_candidates = self._monitor_candidates(trainer)
+        filepath = self.format_checkpoint_name(metrics=monitor_candidates)
+        self._save_checkpoint(trainer, filepath)
+        self._save_last_checkpoint(trainer, monitor_candidates)
+        rank_zero_info(
+            f"An {type(exception).__name__} was raised with message: "
+            f"{str(exception)}, saved checkpoint to {filepath}"
+        )
+
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Ensure save_last=True is applied when training ends."""
         if self.save_last and not self._last_checkpoint_saved:
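The hook only fires when `_should_save_on_exception` allows it (see the next hunk); otherwise the exception simply propagates without a save. A sketch of how the behavior can be exercised, loosely mirroring the tests this commit adds (the module and test names here are illustrative, not the commit's actual test code):

```python
import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class FailingModel(BoringModel):
    """Hypothetical helper that fails partway through the first epoch."""

    def training_step(self, batch, batch_idx):
        if batch_idx == 1:
            raise RuntimeError("simulated failure")
        return super().training_step(batch, batch_idx)


def test_checkpoint_saved_on_exception(tmp_path):
    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_on_exception=True)
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint], max_epochs=1, logger=False)
    # the Trainer re-raises the exception after running the on_exception hooks
    with pytest.raises(RuntimeError, match="simulated failure"):
        trainer.fit(FailingModel())
    # a checkpoint should have been written before the exception propagated
    assert any(path.suffix == ".ckpt" for path in tmp_path.iterdir())
```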
@@ -439,6 +455,14 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
         )
 
+    def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool:
+        return (
+            self.save_on_exception
+            and not bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
+            and not trainer.sanity_checking  # don't save anything during sanity check
+            and self._last_global_step_saved != trainer.global_step  # don't save twice for the same step
+        )
+
     def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
         if self._save_on_train_epoch_end is not None:
             return self._save_on_train_epoch_end
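These guards mirror `_should_skip_saving_checkpoint` just above: `fast_dev_run` and the sanity check suppress the save, and a step that has already been checkpointed is not saved again. A small sketch of the `fast_dev_run` case (illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(save_on_exception=True)
trainer = Trainer(fast_dev_run=True, callbacks=[checkpoint])
# _should_save_on_exception(trainer) is False here: fast_dev_run disables
# checkpointing, so even a crash during fit() writes no .ckpt file
```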
@@ -551,7 +575,7 @@ def _format_checkpoint_name(
         self,
         filename: Optional[str],
         metrics: dict[str, Tensor],
-        prefix: str = "",
+        prefix: Optional[str] = None,
         auto_insert_metric_name: bool = True,
     ) -> str:
         if not filename:
@@ -578,13 +602,17 @@ def _format_checkpoint_name(
                 metrics[name] = torch.tensor(0)
         filename = filename.format(metrics)
 
-        if prefix:
+        if prefix is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
 
         return filename
 
     def format_checkpoint_name(
-        self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
+        self,
+        metrics: dict[str, Tensor],
+        filename: Optional[str] = None,
+        ver: Optional[int] = None,
+        prefix: Optional[str] = None,
     ) -> str:
         """Generate a filename according to the defined template.
@@ -616,7 +644,9 @@ def format_checkpoint_name(
 
         """
         filename = filename or self.filename
-        filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name)
+        filename = self._format_checkpoint_name(
+            filename, metrics, prefix=prefix, auto_insert_metric_name=self.auto_insert_metric_name
+        )
 
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
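The new `prefix` argument is forwarded to `_format_checkpoint_name`, which joins it onto the formatted name with `CHECKPOINT_JOIN_CHAR` ("-" by default). Note the guard also changed from `if prefix:` to `if prefix is not None:`, so an empty string now still triggers the join. A quick sketch of the resulting names (values are illustrative):

```python
import torch
from lightning.pytorch.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(filename="{epoch}-{step}")
name = ckpt.format_checkpoint_name(
    metrics={"epoch": torch.tensor(3), "step": torch.tensor(120)},
    prefix="interrupted",  # new keyword added by this commit
)
print(name)  # interrupted-epoch=3-step=120.ckpt (prefixed with dirpath if one was set)
```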
