@@ -928,6 +928,29 @@ def on_validation_end(self, trainer, pl_module):
928928        pytest .param (TroublemakerOnValidationEnd , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_end" ), 
929929    ], 
930930) 
def test_model_checkpoint_save_on_exception_in_other_callbacks(
    tmp_path, TroubledCallback, expected_checkpoint_global_step
):
    """Test that a checkpoint is saved when an exception is raised in another callback.

    ``every_n_epochs=4`` with a small ``max_epochs`` ensures no regular checkpoint would
    fire, so any saved file must come from the ``save_on_exception`` path.
    """
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback, TroubledCallback()],
        max_epochs=CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS,
        limit_train_batches=CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
        logger=False,
        enable_progress_bar=False,
    )
    # The parametrized TroubledCallback raises RuntimeError("Trouble!") from one of its
    # hooks, aborting the fit at a known global step.
    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    # The exception checkpoint must exist and record the global step at which training
    # was interrupted.
    expected_path = tmp_path / f"step={expected_checkpoint_global_step}.ckpt"
    assert os.path.isfile(expected_path)
    checkpoint = torch.load(expected_path, weights_only=True)
    assert checkpoint["global_step"] == expected_checkpoint_global_step
953+ 
931954@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" ) 
932955def  test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) ->  None :
933956    """Tests that the checkpoints are saved at the specified time interval.""" 
0 commit comments