Skip to content

Commit 0f73167

Browse files
committed
add saving of checkpoint if an exception is raised
1 parent 76d3d22 commit 0f73167

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint):
9797
collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
9898
ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
9999
collisions.
100+
save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``True``.
100101
mode: one of {min, max}.
101102
If ``save_top_k != 0``, the decision to overwrite the current save file is made
102103
based on either the maximization or the minimization of the monitored quantity.
@@ -224,6 +225,7 @@ def __init__(
224225
verbose: bool = False,
225226
save_last: Optional[Union[bool, Literal["link"]]] = None,
226227
save_top_k: int = 1,
228+
save_on_exception: bool = True,
227229
save_weights_only: bool = False,
228230
mode: str = "min",
229231
auto_insert_metric_name: bool = True,
@@ -238,6 +240,7 @@ def __init__(
238240
self.verbose = verbose
239241
self.save_last = save_last
240242
self.save_top_k = save_top_k
243+
self.save_on_exception = save_on_exception
241244
self.save_weights_only = save_weights_only
242245
self.auto_insert_metric_name = auto_insert_metric_name
243246
self._save_on_train_epoch_end = save_on_train_epoch_end
@@ -338,6 +341,17 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
338341
self._save_topk_checkpoint(trainer, monitor_candidates)
339342
self._save_last_checkpoint(trainer, monitor_candidates)
340343

344+
@override
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    """Save a checkpoint when the trainer reports an exception.

    Nothing is saved when ``save_on_exception`` is falsy or when
    ``_should_skip_saving_checkpoint(trainer)`` returns a truthy value.

    Args:
        trainer: The trainer whose state is checkpointed.
        pl_module: The module being trained (unused here).
        exception: The exception that was raised; only its type name is used, in the log message.

    """
    if not self.save_on_exception or self._should_skip_saving_checkpoint(trainer):
        return

    # Build the target path from the currently available metrics, using the
    # same name formatting helper as the other save paths in this callback.
    monitor_candidates = self._monitor_candidates(trainer)
    filepath = self.format_checkpoint_name(metrics=monitor_candidates)

    self._save_checkpoint(trainer, filepath)
    self._save_last_checkpoint(trainer, monitor_candidates)

    # Report the exception type through the rank-zero logger instead of a bare
    # ``print`` so the information is emitted once and not on every process.
    rank_zero_info(f"An exception of type {type(exception).__name__} was raised; saved checkpoint to {filepath}")
353+
354+
341355
@override
342356
def state_dict(self) -> dict[str, Any]:
343357
return {

0 commit comments

Comments
 (0)