@@ -764,6 +764,37 @@ def test_ckpt_every_n_train_steps(tmp_path):
764
764
assert set (os .listdir (tmp_path )) == set (expected )
765
765
766
766
767
def test_model_checkpoint_save_on_exception_in_training_step(tmp_path):
    """Test that the checkpoint is saved when an exception is raised in training_step."""

    class TroubledModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # Fail on the second batch of an epoch; other batches fall through and return ``None``.
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    model = TroubledModel()
    # ``every_n_epochs=4`` keeps the regular epoch-based schedule away from the first epoch, where the
    # failure happens, so a checkpoint found afterwards is attributable to ``save_on_exception``.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    # The error is raised at batch index 1; the test expects the file to be named after step 1.
    assert os.path.isfile(tmp_path / "step=1.ckpt")
781
+
782
def test_model_checkpoint_save_on_exception_in_validation_step(tmp_path):
    """Test that the checkpoint is saved when an exception is raised in validation_step."""

    class TroubledModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            # Let the pre-training sanity check pass and fail on the first real validation batch.
            # ``self.trainer`` is used rather than closing over the local ``trainer`` created below,
            # so the model does not depend on a name that is only bound after the class is defined.
            if not self.trainer.sanity_checking and batch_idx == 0:
                raise RuntimeError("Trouble!")

    model = TroubledModel()
    # Step count expected in the checkpoint name when validation first runs.
    # NOTE(review): assumes BoringModel yields 64 training batches per epoch - confirm against its dataloader.
    epoch_length = 64
    # ``every_n_epochs=4`` keeps the regular epoch-based schedule away from the first epoch, where the
    # failure happens, so a checkpoint found afterwards is attributable to ``save_on_exception``.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    # No whitespace inside the name: it must match the ``{step}`` filename template exactly.
    assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
796
+
797
+
767
798
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start (tmp_path ):
768
799
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
769
800
class TroublemakerOnTrainBatchStart (Callback ):
0 commit comments