@@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint):
97
97
collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
98
98
ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
99
99
collisions.
100
+ save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``True``.
100
101
mode: one of {min, max}.
101
102
If ``save_top_k != 0``, the decision to overwrite the current save file is made
102
103
based on either the maximization or the minimization of the monitored quantity.
@@ -224,6 +225,7 @@ def __init__(
224
225
verbose : bool = False ,
225
226
save_last : Optional [Union [bool , Literal ["link" ]]] = None ,
226
227
save_top_k : int = 1 ,
228
+ save_on_exception : bool = True ,
227
229
save_weights_only : bool = False ,
228
230
mode : str = "min" ,
229
231
auto_insert_metric_name : bool = True ,
@@ -238,6 +240,7 @@ def __init__(
238
240
self .verbose = verbose
239
241
self .save_last = save_last
240
242
self .save_top_k = save_top_k
243
+ self .save_on_exception = save_on_exception
241
244
self .save_weights_only = save_weights_only
242
245
self .auto_insert_metric_name = auto_insert_metric_name
243
246
self ._save_on_train_epoch_end = save_on_train_epoch_end
@@ -338,6 +341,17 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
338
341
self ._save_topk_checkpoint (trainer , monitor_candidates )
339
342
self ._save_last_checkpoint (trainer , monitor_candidates )
340
343
344
@override
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    """Save a checkpoint when an exception is reported to this callback.

    Nothing is written when ``save_on_exception`` is falsy or when
    ``_should_skip_saving_checkpoint`` returns a truthy value for ``trainer``.
    Otherwise a checkpoint is saved under the name produced by
    ``format_checkpoint_name`` for the current monitor candidates, followed by
    a call to ``_save_last_checkpoint``.

    Args:
        trainer: The trainer whose state is checkpointed.
        pl_module: Unused here; part of the hook signature.
        exception: The exception being reported. Only its type name and
            message are used, for the log line.

    """
    if not self.save_on_exception or self._should_skip_saving_checkpoint(trainer):
        return

    monitor_candidates = self._monitor_candidates(trainer)
    filepath = self.format_checkpoint_name(metrics=monitor_candidates)
    self._save_checkpoint(trainer, filepath)
    self._save_last_checkpoint(trainer, monitor_candidates)
    # Report the exception through the rank-zero logger instead of a bare print(),
    # so the message follows the configured logging setup.
    rank_zero_info(
        f"An {type(exception).__name__} was raised with message: {exception!s}; saved checkpoint to {filepath}"
    )
353
+
354
+
341
355
@override
342
356
def state_dict (self ) -> dict [str , Any ]:
343
357
return {
0 commit comments