@@ -764,6 +764,55 @@ def test_ckpt_every_n_train_steps(tmp_path):
764
764
assert set (os .listdir (tmp_path )) == set (expected )
765
765
766
766
767
def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
    """Test that the checkpoint is saved when an exception is raised in a callback on different events."""

    class TroublemakerOnTrainBatchStart(Callback):
        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainBatchEnd(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainEpochStart(Callback):
        def on_train_epoch_start(self, trainer, pl_module):
            if trainer.current_epoch == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainEpochEnd(Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            if trainer.current_epoch == 1:
                raise RuntimeError("Trouble!")

    epoch_length = 64
    model = BoringModel()
    # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)

    troublemakers = [
        TroublemakerOnTrainBatchStart(),
        TroublemakerOnTrainBatchEnd(),
        TroublemakerOnTrainEpochStart(),
        TroublemakerOnTrainEpochEnd(),
    ]

    # One exception checkpoint is expected per troublemaker: the two batch-level
    # hooks fire within the first epoch (steps 1 and 2), while the epoch-level
    # hooks fire at the boundary of epoch 1 (steps 64 and 128).
    expected_ckpts = [
        "step=1.ckpt",
        "step=2.ckpt",
        f"step={epoch_length}.ckpt",
        f"step={2 * epoch_length}.ckpt",
    ]

    for troublemaker in troublemakers:
        trainer = Trainer(
            default_root_dir=tmp_path,
            callbacks=[checkpoint_callback, troublemaker],
            max_epochs=5,
            logger=False,
        )
        with pytest.raises(RuntimeError, match="Trouble!"):
            trainer.fit(model)

    assert set(os.listdir(tmp_path)) == set(expected_ckpts)
767
816
@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" )
768
817
def test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) -> None :
769
818
"""Tests that the checkpoints are saved at the specified time interval."""
0 commit comments