@@ -770,6 +770,49 @@ def test_ckpt_every_n_train_steps(tmp_path):
770770 assert set (os .listdir (tmp_path )) == set (expected )
771771
772772
773+ #################################################################################################
774+
775+
def test_model_checkpoint_on_exception(tmp_path):
    """Test that a checkpoint is saved when an exception is raised in a lightning module.

    Covers exceptions raised from ``training_step`` and from ``validation_step`` (skipping the
    sanity check). ``every_n_epochs=4`` is presumably chosen so the regular checkpointing
    cadence cannot fire before the exception, meaning any checkpoint found on disk must have
    been written by ``save_on_exception``.
    """

    class TroubledModelInTrainingStep(BoringModel):
        def training_step(self, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    class TroubledModelInValidationStep(BoringModel):
        def validation_step(self, batch, batch_idx):
            # Use `self.trainer` (attached by the Trainer during `fit`) rather than the
            # enclosing function's `trainer` local: the latter is a late-binding closure that
            # is only assigned further down in the loop, which is fragile and would bind to a
            # stale trainer on later iterations.
            if not self.trainer.sanity_checking and batch_idx == 1:
                raise RuntimeError("Trouble!")

    models = [TroubledModelInTrainingStep(), TroubledModelInValidationStep()]

    for model in models:
        checkpoint_callback = ModelCheckpoint(
            dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=4
        )
        trainer = Trainer(
            default_root_dir=tmp_path,
            callbacks=[checkpoint_callback],
            limit_train_batches=2,
            max_epochs=5,
            logger=False,
            enable_progress_bar=False,
        )

        with pytest.raises(RuntimeError, match="Trouble!"):
            trainer.fit(model)

        # Checkpoints written on exception are expected to carry an "exception-" prefix.
        checkpoint_path = tmp_path / f"exception-{model.__class__.__name__}.ckpt"

        assert os.path.isfile(checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # The saved checkpoint must contain a non-empty model state dict.
        assert checkpoint["state_dict"] is not None
        assert checkpoint["state_dict"] != {}
814+
815+ #################################################################################################
773816def test_model_checkpoint_save_on_exception_in_training_step (tmp_path ):
774817 """Test that the checkpoint is saved when an exception is raised in training_step."""
775818
@@ -817,6 +860,8 @@ def validation_step(self, batch, batch_idx):
817860 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
818861
819862
863+ #################################################################################################
864+
# Shared parameters for the save-on-exception checkpoint tests below (test bodies not fully
# visible in this chunk — presumably they raise at this batch index / epoch; confirm usage).
CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX = 2
CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH = 21
CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS = 25
@@ -957,6 +1002,9 @@ def test_model_checkpoint_save_on_exception_in_other_callbacks(
9571002 assert checkpoint ["global_step" ] == expected_checkpoint_global_step
9581003
9591004
1005+ #################################################################################################
1006+
1007+
9601008@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" )
9611009def test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) -> None :
9621010 """Tests that the checkpoints are saved at the specified time interval."""
0 commit comments