@@ -764,6 +764,37 @@ def test_ckpt_every_n_train_steps(tmp_path):
764764 assert set (os .listdir (tmp_path )) == set (expected )
765765
766766
767+ def test_model_checkpoint_save_on_exception_in_training_step (tmp_path ):
768+ """Test that the checkpoint is saved when an exception is raised in training_step."""
769+ class TroubledModel (BoringModel ):
770+ def training_step (self , batch , batch_idx ):
771+ if batch_idx == 1 :
772+ raise RuntimeError ("Trouble!" )
773+
774+ model = TroubledModel ()
775+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
776+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ], max_epochs = 5 , logger = False )
777+ with pytest .raises (RuntimeError , match = "Trouble!" ):
778+ trainer .fit (model )
779+ print (os .listdir (tmp_path ))
780+ assert os .path .isfile (tmp_path / "step=1.ckpt" )
781+
782+ def test_model_checkpoint_save_on_exception_in_validation_step (tmp_path ):
783+ """Test that the checkpoint is saved when an exception is raised in validation_step."""
784+ class TroubledModel (BoringModel ):
785+ def validation_step (self , batch , batch_idx ):
786+ if not trainer .sanity_checking and batch_idx == 0 :
787+ raise RuntimeError ("Trouble!" )
788+
789+ model = TroubledModel ()
790+ epoch_length = 64
791+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
792+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ], max_epochs = 5 , logger = False )
793+ with pytest .raises (RuntimeError , match = "Trouble!" ):
794+ trainer .fit (model )
795+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
796+
797+
767798def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start (tmp_path ):
768799 """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
769800 class TroublemakerOnTrainBatchStart (Callback ):
0 commit comments