@@ -811,241 +811,123 @@ def validation_step(self, batch, batch_idx):
811811    assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
812812
813813
814- def  test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start (tmp_path ):
815-     """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start.""" 
814+ CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX  =  2 
815+ CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  =  21 
816+ CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS  =  25 
817+ CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES  =  4 
818+ assert  CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX  <  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES 
819+ assert  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  <  CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS 
816820
817-     class  TroublemakerOnTrainBatchStart (Callback ):
818-         def  on_train_batch_start (self , trainer , pl_module , batch , batch_idx ):
819-             if  batch_idx  ==  1 :
820-                 raise  RuntimeError ("Trouble!" )
821- 
822-     model  =  BoringModel ()
823-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
824-     trainer  =  Trainer (
825-         default_root_dir = tmp_path ,
826-         callbacks = [checkpoint_callback , TroublemakerOnTrainBatchStart ()],
827-         max_epochs = 5 ,
828-         logger = False ,
829-         enable_progress_bar = False ,
830-     )
831-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
832-         trainer .fit (model )
833-     assert  os .path .isfile (tmp_path  /  "step=1.ckpt" )
834- 
835- 
836- def  test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end (tmp_path ):
837-     """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end.""" 
838- 
839-     class  TroublemakerOnTrainBatchEnd (Callback ):
840-         def  on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
841-             if  batch_idx  ==  1 :
842-                 raise  RuntimeError ("Trouble!" )
843- 
844-     model  =  BoringModel ()
845-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
846-     trainer  =  Trainer (
847-         default_root_dir = tmp_path ,
848-         callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()],
849-         max_epochs = 5 ,
850-         logger = False ,
851-         enable_progress_bar = False ,
852-     )
853-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
854-         trainer .fit (model )
855- 
856-     assert  os .path .isfile (tmp_path  /  "step=2.ckpt" )
857- 
858- 
859- def  test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start (tmp_path ):
860-     """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start.""" 
861- 
862-     class  TroublemakerOnTrainEpochStart (Callback ):
863-         def  on_train_epoch_start (self , trainer , pl_module ):
864-             if  trainer .current_epoch  ==  1 :
865-                 raise  RuntimeError ("Trouble!" )
866- 
867-     model  =  BoringModel ()
868-     epoch_length  =  2 
869-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
870-     trainer  =  Trainer (
871-         default_root_dir = tmp_path ,
872-         callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
873-         max_epochs = 5 ,
874-         limit_train_batches = epoch_length ,
875-         logger = False ,
876-         enable_progress_bar = False ,
877-     )
878-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
879-         trainer .fit (model )
880-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
881821
882- 
883- def  test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end (tmp_path ):
884-     """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end.""" 
885- 
886-     class  TroublemakerOnTrainEpochEnd (Callback ):
887-         def  on_train_epoch_end (self , trainer , pl_module ):
888-             if  trainer .current_epoch  ==  1 :
889-                 raise  RuntimeError ("Trouble!" )
890- 
891-     model  =  BoringModel ()
892-     epoch_length  =  2 
893-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
894-     trainer  =  Trainer (
895-         default_root_dir = tmp_path ,
896-         callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
897-         max_epochs = 5 ,
898-         limit_train_batches = epoch_length ,
899-         logger = False ,
900-         enable_progress_bar = False ,
901-     )
902-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
903-         trainer .fit (model )
904-     assert  os .path .isfile (tmp_path  /  f"step={ 2  *  epoch_length }  .ckpt" )
905- 
906- 
907- def  test_model_checkpoint_save_on_exception_in_val_callback (tmp_path ):
908-     """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start.""" 
909- 
910-     class  TroublemakerOnValidationBatchStart (Callback ):
911-         def  on_validation_batch_start (self , trainer , pl_module , batch , batch_idx ):
912-             if  not  trainer .sanity_checking  and  batch_idx  ==  1 :
913-                 raise  RuntimeError ("Trouble!" )
914- 
915-     model  =  BoringModel ()
916-     epoch_length  =  64 
917-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
918-     trainer  =  Trainer (
919-         default_root_dir = tmp_path ,
920-         callbacks = [checkpoint_callback , TroublemakerOnValidationBatchStart ()],
921-         max_epochs = 5 ,
922-         logger = False ,
923-         enable_progress_bar = False ,
924-     )
925-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
926-         trainer .fit (model )
927-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
822+ class  TroublemakerOnTrainBatchStart (Callback ):
823+     def  on_train_batch_start (self , trainer , pl_module , batch , batch_idx ):
824+         if  batch_idx  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX :
825+             raise  RuntimeError ("Trouble!" )
928826
929827
930- def  test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end (tmp_path ):
931-     """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end.""" 
828+ class  TroublemakerOnTrainBatchEnd (Callback ):
829+     def  on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
830+         if  batch_idx  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX :
831+             raise  RuntimeError ("Trouble!" )
932832
933-     class  TroublemakerOnValidationBatchEnd (Callback ):
934-         def  on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
935-             if  not  trainer .sanity_checking  and  batch_idx  ==  1 :
936-                 raise  RuntimeError ("Trouble!" )
937833
938-     model  =  BoringModel ()
939-     epoch_length  =  64 
940-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
941-     trainer  =  Trainer (
942-         default_root_dir = tmp_path ,
943-         callbacks = [checkpoint_callback , TroublemakerOnValidationBatchEnd ()],
944-         max_epochs = 5 ,
945-         logger = False ,
946-         enable_progress_bar = False ,
947-     )
948-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
949-         trainer .fit (model )
950-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
834+ class  TroublemakerOnTrainEpochStart (Callback ):
835+     def  on_train_epoch_start (self , trainer , pl_module ):
836+         if  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
837+             raise  RuntimeError ("Trouble!" )
951838
952839
953- def  test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start (tmp_path ):
954-     """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start.""" 
840+ class  TroublemakerOnTrainEpochEnd (Callback ):
841+     def  on_train_epoch_end (self , trainer , pl_module ):
842+         if  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
843+             raise  RuntimeError ("Trouble!" )
955844
956-     class  TroublemakerOnValidationEpochStart (Callback ):
957-         def  on_validation_epoch_start (self , trainer , pl_module ):
958-             if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  0 :
959-                 raise  RuntimeError ("Trouble!" )
960845
961-     model  =  BoringModel ()
962-     epoch_length  =  2 
963-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
964-     trainer  =  Trainer (
965-         default_root_dir = tmp_path ,
966-         callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
967-         max_epochs = 5 ,
968-         limit_train_batches = epoch_length ,
969-         logger = False ,
970-         enable_progress_bar = False ,
971-     )
972-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
973-         trainer .fit (model )
974-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
846+ class  TroublemakerOnTrainEnd (Callback ):
847+     def  on_train_end (self , trainer , pl_module ):
848+         raise  RuntimeError ("Trouble!" )
975849
976850
977- def  test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end (tmp_path ):
978-     """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end.""" 
851+ class  TroublemakerOnValidationBatchStart (Callback ):
852+     def  on_validation_batch_start (self , trainer , pl_module , batch , batch_idx ):
853+         if  not  trainer .sanity_checking  and  batch_idx  ==  1 :
854+             raise  RuntimeError ("Trouble!" )
979855
980-     class  TroublemakerOnValidationEpochEnd (Callback ):
981-         def  on_validation_epoch_end (self , trainer , pl_module ):
982-             if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  0 :
983-                 raise  RuntimeError ("Trouble!" )
984856
985-     model  =  BoringModel ()
986-     epoch_length  =  2 
987-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
988-     trainer  =  Trainer (
989-         default_root_dir = tmp_path ,
990-         callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
991-         max_epochs = 5 ,
992-         limit_train_batches = epoch_length ,
993-         logger = False ,
994-         enable_progress_bar = False ,
995-     )
996-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
997-         trainer .fit (model )
998-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
857+ class  TroublemakerOnValidationBatchEnd (Callback ):
858+     def  on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
859+         if  not  trainer .sanity_checking  and  batch_idx  ==  1 :
860+             raise  RuntimeError ("Trouble!" )
999861
1000862
1001- def  test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start (tmp_path ):
1002-     """Test that the checkpoint is saved when an exception is raised in a callback on validation_start.""" 
863+ class  TroublemakerOnValidationEpochStart (Callback ):
864+     def  on_validation_epoch_start (self , trainer , pl_module ):
865+         if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
866+             raise  RuntimeError ("Trouble!" )
1003867
1004-     class  TroublemakerOnValidationStart (Callback ):
1005-         def  on_validation_start (self , trainer , pl_module ):
1006-             if  not  trainer .sanity_checking :
1007-                 raise  RuntimeError ("Trouble!" )
1008868
1009-     model  =  BoringModel ()
1010-     epoch_length  =  2 
1011-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
1012-     trainer  =  Trainer (
1013-         default_root_dir = tmp_path ,
1014-         callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
1015-         max_epochs = 5 ,
1016-         limit_train_batches = epoch_length ,
1017-         logger = False ,
1018-         enable_progress_bar = False ,
1019-     )
1020-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
1021-         trainer .fit (model )
1022-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
869+ class  TroublemakerOnValidationEpochEnd (Callback ):
870+     def  on_validation_epoch_end (self , trainer , pl_module ):
871+         if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
872+             raise  RuntimeError ("Trouble!" )
1023873
1024874
1025- def  test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end (tmp_path ):
1026-     """Test that the checkpoint is saved when an exception is raised in a callback on validation_end.""" 
875+ class  TroublemakerOnValidationStart (Callback ):
876+     def  on_validation_start (self , trainer , pl_module ):
877+         if  not  trainer .sanity_checking :
878+             raise  RuntimeError ("Trouble!" )
1027879
1028-     class  TroublemakerOnValidationEnd (Callback ):
1029-         def  on_validation_end (self , trainer , pl_module ):
1030-             if  not  trainer .sanity_checking :
1031-                 raise  RuntimeError ("Trouble!" )
1032880
1033-     model  =  BoringModel ()
1034-     epoch_length  =  2 
1035-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
1036-     trainer  =  Trainer (
1037-         default_root_dir = tmp_path ,
1038-         callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
1039-         max_epochs = 5 ,
1040-         limit_train_batches = epoch_length ,
1041-         logger = False ,
1042-         enable_progress_bar = False ,
1043-     )
1044-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
1045-         trainer .fit (model )
1046-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
881+ class  TroublemakerOnValidationEnd (Callback ):
882+     def  on_validation_end (self , trainer , pl_module ):
883+         if  not  trainer .sanity_checking :
884+             raise  RuntimeError ("Trouble!" )
1047885
1048886
887+ @pytest .mark .parametrize ( 
888+     ("TroubledCallback" , "expected_checkpoint_global_step" ), 
889+     [ 
890+         pytest .param ( 
891+             TroublemakerOnTrainBatchStart , CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX , id = "on_train_batch_start"  
892+         ), 
893+         pytest .param ( 
894+             TroublemakerOnTrainBatchEnd , CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX  +  1 , id = "on_train_batch_end"  
895+         ), 
896+         pytest .param ( 
897+             TroublemakerOnTrainEpochStart , 
898+             CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
899+             id = "on_train_epoch_start" , 
900+         ), 
901+         pytest .param ( 
902+             TroublemakerOnTrainEpochEnd , 
903+             (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  +  1 ) *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
904+             id = "on_train_epoch_end" , 
905+         ), 
906+         pytest .param ( 
907+             TroublemakerOnTrainEnd , 
908+             CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS  *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
909+             id = "on_train_end" , 
910+         ), 
911+         pytest .param ( 
912+             TroublemakerOnValidationBatchStart , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_batch_start"  
913+         ), 
914+         pytest .param ( 
915+             TroublemakerOnValidationBatchEnd , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_batch_end"  
916+         ), 
917+         pytest .param ( 
918+             TroublemakerOnValidationEpochStart , 
919+             (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  +  1 ) *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
920+             id = "on_validation_epoch_start" , 
921+         ), 
922+         pytest .param ( 
923+             TroublemakerOnValidationEpochEnd , 
924+             (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  +  1 ) *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
925+             id = "on_validation_epoch_end" , 
926+         ), 
927+         pytest .param (TroublemakerOnValidationStart , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_start" ), 
928+         pytest .param (TroublemakerOnValidationEnd , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_end" ), 
929+     ], 
930+ ) 
1049931@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" ) 
1050932def  test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) ->  None :
1051933    """Tests that the checkpoints are saved at the specified time interval.""" 
0 commit comments