@@ -770,6 +770,73 @@ def test_ckpt_every_n_train_steps(tmp_path):
770770 assert set (os .listdir (tmp_path )) == set (expected )
771771
772772
773+ def test_model_checkpoint_on_exception_run_condition (tmp_path ):
774+ """Test that the checkpoint is saved when an exception is raised in a lightning module."""
775+
776+ # Don't save checkpoint if sanity check fails
777+ class TroubledModelSanityCheck (BoringModel ):
778+ def on_validation_start (self ) -> None :
779+ if self .trainer .sanity_checking :
780+ print ("Trouble!" )
781+ raise RuntimeError ("Trouble!" )
782+
783+ model = TroubledModelSanityCheck ()
784+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "sanity_check" , save_on_exception = True )
785+ trainer = Trainer (
786+ default_root_dir = tmp_path ,
787+ num_sanity_val_steps = 4 ,
788+ limit_train_batches = 2 ,
789+ callbacks = [checkpoint_callback ],
790+ max_epochs = 2 ,
791+ logger = False ,
792+ )
793+
794+ with pytest .raises (RuntimeError , match = "Trouble!" ):
795+ trainer .fit (model )
796+ assert not os .path .isfile (tmp_path / "exception-sanity_check.ckpt" )
797+
798+ # Don't save checkpoint if fast dev run fails
799+ class TroubledModelFastDevRun (BoringModel ):
800+ def on_train_batch_start (self , batch , batch_idx ) -> None :
801+ if self .trainer .fast_dev_run and batch_idx == 1 :
802+ raise RuntimeError ("Trouble!" )
803+
804+ model = TroubledModelFastDevRun ()
805+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "fast_dev_run" , save_on_exception = True )
806+ trainer = Trainer (
807+ default_root_dir = tmp_path ,
808+ fast_dev_run = 2 ,
809+ limit_train_batches = 2 ,
810+ callbacks = [checkpoint_callback ],
811+ max_epochs = 2 ,
812+ logger = False ,
813+ )
814+
815+ with pytest .raises (RuntimeError , match = "Trouble!" ):
816+ trainer .fit (model )
817+ assert not os .path .isfile (tmp_path / "exception-fast_dev_run.ckpt" )
818+
819+ # Don't save checkpoint if already saved a checkpoint
820+ class TroubledModelAlreadySavedCheckpoint (BoringModel ):
821+ def on_train_batch_start (self , batch , batch_idx ) -> None :
822+ if self .trainer .global_step == 1 :
823+ raise RuntimeError ("Trouble!" )
824+
825+ model = TroubledModelAlreadySavedCheckpoint ()
826+ checkpoint_callback = ModelCheckpoint (
827+ dirpath = tmp_path , filename = "already_saved" , save_on_exception = True , every_n_train_steps = 1
828+ )
829+ trainer = Trainer (
830+ default_root_dir = tmp_path , limit_train_batches = 2 , callbacks = [checkpoint_callback ], max_epochs = 2 , logger = False
831+ )
832+
833+ with pytest .raises (RuntimeError , match = "Trouble!" ):
834+ trainer .fit (model )
835+
836+ assert not os .path .isfile (tmp_path / "exception-already_saved.ckpt" )
837+ assert os .path .isfile (tmp_path / "already_saved.ckpt" )
838+
839+
773840def test_model_checkpoint_on_exception (tmp_path ):
774841 """Test that the checkpoint is saved when an exception is raised in a lightning module."""
775842
0 commit comments