@@ -764,6 +764,55 @@ def test_ckpt_every_n_train_steps(tmp_path):
764764 assert set (os .listdir (tmp_path )) == set (expected )
765765
766766
def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
    """Test that a checkpoint is saved when a callback raises in different training hooks.

    Four callbacks each raise ``RuntimeError("Trouble!")`` from a different hook (train batch start, train
    batch end, train epoch start, train epoch end). Each one is run in its own ``Trainer`` together with a
    single shared ``ModelCheckpoint`` created with ``save_on_exception=True``. After all four runs,
    ``tmp_path`` must contain exactly the expected checkpoint files and nothing else.

    """

    class TroublemakerOnTrainBatchStart(Callback):
        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainBatchEnd(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainEpochStart(Callback):
        def on_train_epoch_start(self, trainer, pl_module):
            if trainer.current_epoch == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainEpochEnd(Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            if trainer.current_epoch == 1:
                raise RuntimeError("Trouble!")

    # NOTE(review): assumes BoringModel yields 64 training batches per epoch - confirm against its dataloader.
    epoch_length = 64
    model = BoringModel()
    # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )

    troublemakers = [
        TroublemakerOnTrainBatchStart(),
        TroublemakerOnTrainBatchEnd(),
        TroublemakerOnTrainEpochStart(),
        TroublemakerOnTrainEpochEnd(),
    ]

    # One file per troublemaker, listed in the same order as `troublemakers`; presumably the number in each
    # name is the global step reached when the corresponding hook raised. The names must match the
    # `filename="{step}"` template exactly, so no whitespace may precede the ".ckpt" suffix.
    expected_ckpts = [
        "step=1.ckpt",
        "step=2.ckpt",
        f"step={epoch_length}.ckpt",
        f"step={2 * epoch_length}.ckpt",
    ]

    for troublemaker in troublemakers:
        trainer = Trainer(
            default_root_dir=tmp_path,
            callbacks=[checkpoint_callback, troublemaker],
            max_epochs=5,
            logger=False,
        )

        # The callback's exception must still propagate out of ``fit``.
        with pytest.raises(RuntimeError, match="Trouble!"):
            trainer.fit(model)

    assert set(os.listdir(tmp_path)) == set(expected_ckpts)
815+
767816@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" )
768817def test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) -> None :
769818 """Tests that the checkpoints are saved at the specified time interval."""
0 commit comments