@@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint):
9797 collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
9898 ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
9999 collisions.
100+ save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``True``.
100101 mode: one of {min, max}.
101102 If ``save_top_k != 0``, the decision to overwrite the current save file is made
102103 based on either the maximization or the minimization of the monitored quantity.
@@ -224,6 +225,7 @@ def __init__(
224225 verbose : bool = False ,
225226 save_last : Optional [Union [bool , Literal ["link" ]]] = None ,
226227 save_top_k : int = 1 ,
228+ save_on_exception : bool = True ,
227229 save_weights_only : bool = False ,
228230 mode : str = "min" ,
229231 auto_insert_metric_name : bool = True ,
@@ -238,6 +240,7 @@ def __init__(
238240 self .verbose = verbose
239241 self .save_last = save_last
240242 self .save_top_k = save_top_k
243+ self .save_on_exception = save_on_exception
241244 self .save_weights_only = save_weights_only
242245 self .auto_insert_metric_name = auto_insert_metric_name
243246 self ._save_on_train_epoch_end = save_on_train_epoch_end
@@ -338,6 +341,17 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
338341 self ._save_topk_checkpoint (trainer , monitor_candidates )
339342 self ._save_last_checkpoint (trainer , monitor_candidates )
340343
@override
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    """Save a checkpoint when an exception is raised, if ``save_on_exception`` is enabled.

    Nothing is saved when ``save_on_exception`` is False or when
    ``_should_skip_saving_checkpoint`` reports that saving should be skipped.

    Args:
        trainer: The trainer instance passed to the hook.
        pl_module: The module passed to the hook (unused here).
        exception: The exception that was raised; only its type name is used, for the log message.

    """
    if not self.save_on_exception or self._should_skip_saving_checkpoint(trainer):
        return

    monitor_candidates = self._monitor_candidates(trainer)
    # NOTE(review): the name is built from the current metrics only, so it may coincide with an
    # existing checkpoint file -- confirm whether overwriting is acceptable here.
    filepath = self.format_checkpoint_name(metrics=monitor_candidates)
    self._save_checkpoint(trainer, filepath)
    self._save_last_checkpoint(trainer, monitor_candidates)
    # Report through the rank-zero logger rather than a bare ``print`` so the message is emitted once.
    rank_zero_info(f"{type(exception).__name__} was raised; saved checkpoint to {filepath}")
353+
354+
341355 @override
342356 def state_dict (self ) -> dict [str , Any ]:
343357 return {
0 commit comments