@@ -827,66 +827,100 @@ def on_train_epoch_end(self, trainer, pl_module):
827827 assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
828828
829829
830- # def test_model_checkpoint_save_on_exception_in_train_callback (tmp_path):
831- # """Test that the checkpoint is saved when an exception is raised in a callback on different events ."""
832- # class TroublemakerOnTrainBatchStart (Callback):
833- # def on_train_batch_start (self, trainer, pl_module, batch, batch_idx):
834- # if batch_idx == 1:
835- # raise RuntimeError("Trouble!")
def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path):
    """Check that a checkpoint file exists after a callback raises in ``on_validation_batch_start``.

    A callback raises ``RuntimeError("Trouble!")`` at validation batch index 1 (outside of
    sanity checking). The test expects ``trainer.fit`` to propagate that error and a file
    named ``step=<epoch_length>.ckpt`` to be present in ``tmp_path`` afterwards.
    """

    class TroublemakerOnValidationBatchStart(Callback):
        """Callback that fails at the start of validation batch index 1."""

        def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
            # Leave the sanity check and every batch other than index 1 untouched.
            if trainer.sanity_checking or batch_idx != 1:
                return
            raise RuntimeError("Trouble!")

    # NOTE(review): assumes one training epoch amounts to 64 steps — confirm against BoringModel.
    epoch_length = 64
    model = BoringModel()
    # every_n_epochs=4 is presumably there so regular checkpoints are not confused with the
    # one written because of the exception.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
864844
865- # expected_ckpts = ["step=1.ckpt",
866- # 'step=2.ckpt',
867- # f'step={epoch_length}.ckpt',
868- # f'step={2*epoch_length}.ckpt',
869- # ]
870845
871- # for troublemaker in troublemakers:
872- # trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path):
    """Check that a checkpoint file exists after a callback raises in ``on_validation_batch_end``.

    A callback raises ``RuntimeError("Trouble!")`` at validation batch index 1 (outside of
    sanity checking). The test expects ``trainer.fit`` to propagate that error and a file
    named ``step=<epoch_length>.ckpt`` to be present in ``tmp_path`` afterwards.
    """

    class TroublemakerOnValidationBatchEnd(Callback):
        """Callback that fails at the end of validation batch index 1."""

        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # Leave the sanity check and every batch other than index 1 untouched.
            if trainer.sanity_checking or batch_idx != 1:
                return
            raise RuntimeError("Trouble!")

    # NOTE(review): assumes one training epoch amounts to 64 steps — confirm against BoringModel.
    epoch_length = 64
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
873860
874- # with pytest.raises(RuntimeError, match="Trouble!"):
875- # trainer.fit(model)
876861
877- # assert set(os.listdir(tmp_path)) == set(expected_ckpts)
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start(tmp_path):
    """Check that a checkpoint file exists after a callback raises in ``on_validation_epoch_start``.

    A callback raises ``RuntimeError("Trouble!")`` when validation starts for epoch 0 (outside
    of sanity checking). The test expects ``trainer.fit`` to propagate that error and a file
    named ``step=<epoch_length>.ckpt`` to be present in ``tmp_path`` afterwards.
    """

    class TroublemakerOnValidationEpochStart(Callback):
        """Callback that fails when the validation epoch starts during epoch 0."""

        def on_validation_epoch_start(self, trainer, pl_module):
            # Only the first real (non-sanity) epoch triggers the failure.
            if trainer.sanity_checking or trainer.current_epoch != 0:
                return
            raise RuntimeError("Trouble!")

    # NOTE(review): assumes one training epoch amounts to 64 steps — confirm against BoringModel.
    epoch_length = 64
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
876+
879877
880- # # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
881- # checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path):
    """Check that a checkpoint file exists after a callback raises in ``on_validation_epoch_end``.

    A callback raises ``RuntimeError("Trouble!")`` when validation ends for epoch 0 (outside
    of sanity checking). The test expects ``trainer.fit`` to propagate that error and a file
    named ``step=<epoch_length>.ckpt`` to be present in ``tmp_path`` afterwards.
    """

    class TroublemakerOnValidationEpochEnd(Callback):
        """Callback that fails when the validation epoch ends during epoch 0."""

        def on_validation_epoch_end(self, trainer, pl_module):
            # Only the first real (non-sanity) epoch triggers the failure.
            if trainer.sanity_checking or trainer.current_epoch != 0:
                return
            raise RuntimeError("Trouble!")

    # NOTE(review): assumes one training epoch amounts to 64 steps — confirm against BoringModel.
    epoch_length = 64
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
882892
883- # troublemakers = [
884- # # TroublemakerOnValidationBatchStart(),
885- # TroublemakerOnValidationBatchEnd(),
886- # expected_ckpts = [f"step={2*epoch_length}.ckpt",
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start(tmp_path):
    """Check that a checkpoint file exists after a callback raises in ``on_validation_start``.

    A callback raises ``RuntimeError("Trouble!")`` the first time validation starts outside of
    sanity checking. The test expects ``trainer.fit`` to propagate that error and a file
    named ``step=<epoch_length>.ckpt`` to be present in ``tmp_path`` afterwards.
    """

    class TroublemakerOnValidationStart(Callback):
        """Callback that fails whenever a non-sanity validation run starts."""

        def on_validation_start(self, trainer, pl_module):
            # The sanity check is allowed to pass; any real validation run fails.
            if trainer.sanity_checking:
                return
            raise RuntimeError("Trouble!")

    # NOTE(review): assumes one training epoch amounts to 64 steps — confirm against BoringModel.
    epoch_length = 64
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback, TroublemakerOnValidationStart()],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
887907
888- assert set (os .listdir (tmp_path )) == set (expected_ckpts )
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end(tmp_path):
    """Check that a checkpoint file exists after a callback raises in ``on_validation_end``.

    A callback raises ``RuntimeError("Trouble!")`` the first time validation ends outside of
    sanity checking. The test expects ``trainer.fit`` to propagate that error and a file
    named ``step=<epoch_length>.ckpt`` to be present in ``tmp_path`` afterwards.
    """

    class TroublemakerOnValidationEnd(Callback):
        """Callback that fails whenever a non-sanity validation run ends."""

        def on_validation_end(self, trainer, pl_module):
            # The sanity check is allowed to pass; any real validation run fails.
            if trainer.sanity_checking:
                return
            raise RuntimeError("Trouble!")

    # NOTE(review): assumes one training epoch amounts to 64 steps — confirm against BoringModel.
    epoch_length = 64
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()],
        max_epochs=5,
        logger=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
889922
923+
890924@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" )
891925def test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) -> None :
892926 """Tests that the checkpoints are saved at the specified time interval."""
0 commit comments