@@ -827,66 +827,100 @@ def on_train_epoch_end(self, trainer, pl_module):
827
827
assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
828
828
829
829
830
- # def test_model_checkpoint_save_on_exception_in_train_callback (tmp_path):
831
- # """Test that the checkpoint is saved when an exception is raised in a callback on different events ."""
832
- # class TroublemakerOnTrainBatchStart (Callback):
833
- # def on_train_batch_start (self, trainer, pl_module, batch, batch_idx):
834
- # if batch_idx == 1:
835
- # raise RuntimeError("Trouble!")
830
+ def test_model_checkpoint_save_on_exception_in_val_callback (tmp_path ):
831
+ """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start ."""
832
+ class TroublemakerOnValidationBatchStart (Callback ):
833
+ def on_validation_batch_start (self , trainer , pl_module , batch , batch_idx ):
834
+ if not trainer . sanity_checking and batch_idx == 1 :
835
+ raise RuntimeError ("Trouble!" )
836
836
837
- # class TroublemakerOnTrainBatchEnd(Callback):
838
- # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
839
- # if batch_idx == 1:
840
- # raise RuntimeError("Trouble!")
841
-
842
- # class TroublemakerOnTrainEpochStart(Callback):
843
- # def on_train_epoch_start(self, trainer, pl_module):
844
- # if trainer.current_epoch == 1:
845
- # raise RuntimeError("Trouble!")
846
-
847
- # class TroublemakerOnTrainEpochEnd(Callback):
848
- # def on_train_epoch_end(self, trainer, pl_module):
849
- # if trainer.current_epoch == 1:
850
- # raise RuntimeError("Trouble!")
851
-
852
-
853
- # epoch_length = 64
854
- # model = BoringModel()
855
- # # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
856
- # checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
857
-
858
- # troublemakers = [
859
- # TroublemakerOnTrainBatchStart(),
860
- # TroublemakerOnTrainBatchEnd(),
861
- # TroublemakerOnTrainEpochStart(),
862
- # TroublemakerOnTrainEpochEnd()
863
- # ]
837
+ model = BoringModel ()
838
+ epoch_length = 64
839
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
840
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchStart ()], max_epochs = 5 , logger = False )
841
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
842
+ trainer .fit (model )
843
+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
864
844
865
- # expected_ckpts = ["step=1.ckpt",
866
- # 'step=2.ckpt',
867
- # f'step={epoch_length}.ckpt',
868
- # f'step={2*epoch_length}.ckpt',
869
- # ]
870
845
871
- # for troublemaker in troublemakers:
872
- # trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
846
+ def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end (tmp_path ):
847
+ """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end."""
848
+ class TroublemakerOnValidationBatchEnd (Callback ):
849
+ def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
850
+ if not trainer .sanity_checking and batch_idx == 1 :
851
+ raise RuntimeError ("Trouble!" )
852
+
853
+ model = BoringModel ()
854
+ epoch_length = 64
855
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
856
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchEnd ()], max_epochs = 5 , logger = False )
857
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
858
+ trainer .fit (model )
859
+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
873
860
874
- # with pytest.raises(RuntimeError, match="Trouble!"):
875
- # trainer.fit(model)
876
861
877
- # assert set(os.listdir(tmp_path)) == set(expected_ckpts)
862
+ def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start (tmp_path ):
863
+ """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start."""
864
+ class TroublemakerOnValidationEpochStart (Callback ):
865
+ def on_validation_epoch_start (self , trainer , pl_module ):
866
+ if not trainer .sanity_checking and trainer .current_epoch == 0 :
867
+ raise RuntimeError ("Trouble!" )
878
868
869
+ model = BoringModel ()
870
+ epoch_length = 64
871
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
872
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()], max_epochs = 5 , logger = False )
873
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
874
+ trainer .fit (model )
875
+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
876
+
879
877
880
- # # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
881
- # checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
878
+ def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end (tmp_path ):
879
+ """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
880
+ class TroublemakerOnValidationEpochEnd (Callback ):
881
+ def on_validation_epoch_end (self , trainer , pl_module ):
882
+ if not trainer .sanity_checking and trainer .current_epoch == 0 :
883
+ raise RuntimeError ("Trouble!" )
884
+
885
+ model = BoringModel ()
886
+ epoch_length = 64
887
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
888
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()], max_epochs = 5 , logger = False )
889
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
890
+ trainer .fit (model )
891
+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
882
892
883
- # troublemakers = [
884
- # # TroublemakerOnValidationBatchStart(),
885
- # TroublemakerOnValidationBatchEnd(),
886
- # expected_ckpts = [f"step={2*epoch_length}.ckpt",
893
+ def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start (tmp_path ):
894
+ """Test that the checkpoint is saved when an exception is raised in a callback on validation_start."""
895
+ class TroublemakerOnValidationStart (Callback ):
896
+ def on_validation_start (self , trainer , pl_module ):
897
+ if not trainer .sanity_checking :
898
+ raise RuntimeError ("Trouble!" )
899
+
900
+ model = BoringModel ()
901
+ epoch_length = 64
902
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
903
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()], max_epochs = 5 , logger = False )
904
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
905
+ trainer .fit (model )
906
+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
887
907
888
- assert set (os .listdir (tmp_path )) == set (expected_ckpts )
908
+ def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end (tmp_path ):
909
+ """Test that the checkpoint is saved when an exception is raised in a callback on validation_end."""
910
+ class TroublemakerOnValidationEnd (Callback ):
911
+ def on_validation_end (self , trainer , pl_module ):
912
+ if not trainer .sanity_checking :
913
+ raise RuntimeError ("Trouble!" )
914
+
915
+ model = BoringModel ()
916
+ epoch_length = 64
917
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
918
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()], max_epochs = 5 , logger = False )
919
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
920
+ trainer .fit (model )
921
+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
889
922
923
+
890
924
@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" )
891
925
def test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) -> None :
892
926
"""Tests that the checkpoints are saved at the specified time interval."""
0 commit comments