@@ -813,196 +813,101 @@ def validation_step(self, batch, batch_idx):
813813
814814
815815################################################################################################# 
816- def  test_model_checkpoint_save_on_exception_in_training_step (tmp_path ):
817-     """Test that the  checkpoint is saved when an exception is raised in training_step .""" 
816+ def  test_model_checkpoint_on_exception_in_other_callbacks (tmp_path ):
817+     """Test that a checkpoint is saved when an exception is raised in another callback."""
818818
819-     class  TroubledModel ( BoringModel ):
820-         def  training_step (self , batch , batch_idx ):
819+     class  TroubleMakerOnTrainBatchStart ( Callback ):  # fails from the on_train_batch_start hook
820+         def  on_train_batch_start (self ,  trainer ,  pl_module , batch , batch_idx ):  # raises RuntimeError("Trouble!") when batch_idx == 1
821821            if  batch_idx  ==  1 :
822822                raise  RuntimeError ("Trouble!" )
823823
824-     model  =  TroubledModel ()
825-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
826-     trainer  =  Trainer (
827-         default_root_dir = tmp_path ,
828-         callbacks = [checkpoint_callback ],
829-         max_epochs = 5 ,
830-         logger = False ,
831-         enable_progress_bar = False ,
832-     )
833-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
834-         trainer .fit (model )
835-     print (os .listdir (tmp_path ))
836-     assert  os .path .isfile (tmp_path  /  "step=1.ckpt" )
837- 
838- 
839- def  test_model_checkpoint_save_on_exception_in_validation_step (tmp_path ):
840-     """Test that the checkpoint is saved when an exception is raised in validation_step.""" 
841- 
842-     class  TroubledModel (BoringModel ):
843-         def  validation_step (self , batch , batch_idx ):
844-             if  not  trainer .sanity_checking  and  batch_idx  ==  0 :
824+     class  TroubleMakerOnTrainBatchEnd (Callback ):  # fails from the on_train_batch_end hook
825+         def  on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
826+             if  batch_idx  ==  1 :  # only the batch with index 1 triggers the RuntimeError
845827                raise  RuntimeError ("Trouble!" )
846828
847-     model  =  TroubledModel ()
848-     epoch_length  =  2 
849-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
850-     trainer  =  Trainer (
851-         default_root_dir = tmp_path ,
852-         callbacks = [checkpoint_callback ],
853-         max_epochs = 5 ,
854-         limit_train_batches = epoch_length ,
855-         logger = False ,
856-         enable_progress_bar = False ,
857-     )
858-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
859-         trainer .fit (model )
860-     assert  os .path .isfile (tmp_path  /  f"step={ epoch_length }  .ckpt" )
861- 
862- 
863- ################################################################################################# 
864- 
865- CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX  =  2 
866- CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  =  21 
867- CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS  =  25 
868- CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES  =  4 
869- assert  CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX  <  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES 
870- assert  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  <  CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS 
871- 
872- 
873- class  TroublemakerOnTrainBatchStart (Callback ):
874-     def  on_train_batch_start (self , trainer , pl_module , batch , batch_idx ):
875-         if  batch_idx  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX :
876-             raise  RuntimeError ("Trouble!" )
877- 
878- 
879- class  TroublemakerOnTrainBatchEnd (Callback ):
880-     def  on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
881-         if  batch_idx  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX :
882-             raise  RuntimeError ("Trouble!" )
883- 
884- 
885- class  TroublemakerOnTrainEpochStart (Callback ):
886-     def  on_train_epoch_start (self , trainer , pl_module ):
887-         if  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
888-             raise  RuntimeError ("Trouble!" )
889- 
890- 
891- class  TroublemakerOnTrainEpochEnd (Callback ):
892-     def  on_train_epoch_end (self , trainer , pl_module ):
893-         if  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
894-             raise  RuntimeError ("Trouble!" )
895- 
896- 
897- class  TroublemakerOnTrainEnd (Callback ):
898-     def  on_train_end (self , trainer , pl_module ):
899-         raise  RuntimeError ("Trouble!" )
900- 
901- 
902- class  TroublemakerOnValidationBatchStart (Callback ):
903-     def  on_validation_batch_start (self , trainer , pl_module , batch , batch_idx ):
904-         if  not  trainer .sanity_checking  and  batch_idx  ==  1 :
905-             raise  RuntimeError ("Trouble!" )
829+     class  TroubleMakerOnTrainEpochStart (Callback ):  # fails from the on_train_epoch_start hook
830+         def  on_train_epoch_start (self , trainer , pl_module ):
831+             if  trainer .current_epoch  ==  1 :  # earlier epochs run untouched
832+                 raise  RuntimeError ("Trouble!" )
906833
834+     class  TroubleMakerOnTrainEpochEnd (Callback ):  # fails from the on_train_epoch_end hook
835+         def  on_train_epoch_end (self , trainer , pl_module ):
836+             if  trainer .current_epoch  ==  1 :  # raise only while trainer.current_epoch is 1
837+                 raise  RuntimeError ("Trouble!" )
907838
908- class  TroublemakerOnValidationBatchEnd (Callback ):
909-     def  on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
910-         if  not  trainer .sanity_checking  and  batch_idx  ==  1 :
839+     class  TroubleMakerOnTrainEnd (Callback ):  # fails from the on_train_end hook
840+         def  on_train_end (self , trainer , pl_module ):  # raises unconditionally
911841            raise  RuntimeError ("Trouble!" )
912842
843+     class  TroubleMakerOnValidationBatchStart (Callback ):  # fails from the on_validation_batch_start hook
844+         def  on_validation_batch_start (self , trainer , pl_module , batch , batch_idx ):
845+             if  not  trainer .sanity_checking  and  batch_idx  ==  1 :  # never raise while trainer.sanity_checking is set
846+                 raise  RuntimeError ("Trouble!" )
913847
914- class  TroublemakerOnValidationEpochStart (Callback ):
915-     def  on_validation_epoch_start (self , trainer , pl_module ):
916-         if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
917-             raise  RuntimeError ("Trouble!" )
918- 
848+     class  TroubleMakerOnValidationBatchEnd (Callback ):  # fails from the on_validation_batch_end hook
849+         def  on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
850+             if  not  trainer .sanity_checking  and  batch_idx  ==  1 :  # never raise while trainer.sanity_checking is set
851+                 raise  RuntimeError ("Trouble!" )
919852
920- class  TroublemakerOnValidationEpochEnd (Callback ):
921-     def  on_validation_epoch_end (self , trainer , pl_module ):
922-         if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH :
923-             raise  RuntimeError ("Trouble!" )
853+     class  TroubleMakerOnValidationEpochStart (Callback ):  # fails from the on_validation_epoch_start hook
854+         def  on_validation_epoch_start (self , trainer , pl_module ):
855+             if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  1 :  # never raise while trainer.sanity_checking is set
856+                 raise  RuntimeError ("Trouble!" )
924857
858+     class  TroubleMakerOnValidationEpochEnd (Callback ):  # fails from the on_validation_epoch_end hook
859+         def  on_validation_epoch_end (self , trainer , pl_module ):
860+             if  not  trainer .sanity_checking  and  trainer .current_epoch  ==  1 :  # never raise while trainer.sanity_checking is set
861+                 raise  RuntimeError ("Trouble!" )
925862
926- class  TroublemakerOnValidationStart (Callback ):
927-     def  on_validation_start (self , trainer , pl_module ):
928-         if  not  trainer .sanity_checking :
929-             raise  RuntimeError ("Trouble!" )
863+     class  TroubleMakerOnValidationStart (Callback ):  # fails from the on_validation_start hook
864+         def  on_validation_start (self , trainer , pl_module ):
865+             if  not  trainer .sanity_checking :  # raise on any validation start outside sanity checking
866+                 raise  RuntimeError ("Trouble!" )
930867
868+     class  TroubleMakerOnValidationEnd (Callback ):  # fails from the on_validation_end hook
869+         def  on_validation_end (self , trainer , pl_module ):
870+             if  not  trainer .sanity_checking :  # raise on any validation end outside sanity checking
871+                 raise  RuntimeError ("Trouble!" )
931872
932- class  TroublemakerOnValidationEnd (Callback ):
933-     def  on_validation_end (self , trainer , pl_module ):
934-         if  not  trainer .sanity_checking :
873+     class  TroubleMakerOnFitEnd (Callback ):  # fails from the on_fit_end hook
874+         def  on_fit_end (self , trainer , pl_module ):  # raises unconditionally
935875            raise  RuntimeError ("Trouble!" )
936876
877+     troubled_callbacks  =  [  # one instance of every troublemaker callback defined in this test
878+         TroubleMakerOnTrainBatchStart (),
879+         TroubleMakerOnTrainBatchEnd (),
880+         TroubleMakerOnTrainEpochStart (),
881+         TroubleMakerOnTrainEpochEnd (),
882+         TroubleMakerOnTrainEnd (),
883+         TroubleMakerOnValidationBatchStart (),
884+         TroubleMakerOnValidationBatchEnd (),
885+         TroubleMakerOnValidationEpochStart (),
886+         TroubleMakerOnValidationEpochEnd (),
887+         TroubleMakerOnValidationStart (),
888+         TroubleMakerOnValidationEnd (),
889+         TroubleMakerOnFitEnd (),
890+     ]
937891
938- @pytest .mark .parametrize ( 
939-     ("TroubledCallback" , "expected_checkpoint_global_step" ), 
940-     [ 
941-         pytest .param ( 
942-             TroublemakerOnTrainBatchStart , CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX , id = "on_train_batch_start"  
943-         ), 
944-         pytest .param ( 
945-             TroublemakerOnTrainBatchEnd , CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX  +  1 , id = "on_train_batch_end"  
946-         ), 
947-         pytest .param ( 
948-             TroublemakerOnTrainEpochStart , 
949-             CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
950-             id = "on_train_epoch_start" , 
951-         ), 
952-         pytest .param ( 
953-             TroublemakerOnTrainEpochEnd , 
954-             (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  +  1 ) *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
955-             id = "on_train_epoch_end" , 
956-         ), 
957-         pytest .param ( 
958-             TroublemakerOnTrainEnd , 
959-             CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS  *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
960-             id = "on_train_end" , 
961-         ), 
962-         pytest .param ( 
963-             TroublemakerOnValidationBatchStart , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_batch_start"  
964-         ), 
965-         pytest .param ( 
966-             TroublemakerOnValidationBatchEnd , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_batch_end"  
967-         ), 
968-         pytest .param ( 
969-             TroublemakerOnValidationEpochStart , 
970-             (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  +  1 ) *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
971-             id = "on_validation_epoch_start" , 
972-         ), 
973-         pytest .param ( 
974-             TroublemakerOnValidationEpochEnd , 
975-             (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH  +  1 ) *  CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , 
976-             id = "on_validation_epoch_end" , 
977-         ), 
978-         pytest .param (TroublemakerOnValidationStart , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_start" ), 
979-         pytest .param (TroublemakerOnValidationEnd , CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES , id = "on_validation_end" ), 
980-     ], 
981- ) 
982- def  test_model_checkpoint_save_on_exception_in_other_callbacks (
983-     tmp_path , TroubledCallback , expected_checkpoint_global_step 
984- ):
985-     """Test that an checkpoint is saved when an exception is raised in an other callback.""" 
986- 
987-     model  =  BoringModel ()
988-     checkpoint_callback  =  ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
989-     trainer  =  Trainer (
990-         default_root_dir = tmp_path ,
991-         callbacks = [checkpoint_callback , TroubledCallback ()],
992-         max_epochs = CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS ,
993-         limit_train_batches = CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES ,
994-         logger = False ,
995-         enable_progress_bar = False ,
996-     )
997-     with  pytest .raises (RuntimeError , match = "Trouble!" ):
998-         trainer .fit (model )
999- 
1000-     assert  os .path .isfile (tmp_path  /  f"step={ expected_checkpoint_global_step }  .ckpt" )
1001-     checkpoint  =  torch .load (tmp_path  /  f"step={ expected_checkpoint_global_step }  .ckpt" , weights_only = True )
1002-     assert  checkpoint ["global_step" ] ==  expected_checkpoint_global_step 
1003- 
1004- 
1005- ################################################################################################# 
892+     for troubled_callback in troubled_callbacks:  # one fit per callback; runs share tmp_path and differ by filename
893+         callback_name = type(troubled_callback).__name__
894+         checkpoint_callback = ModelCheckpoint(
895+             dirpath=tmp_path, filename=callback_name, save_on_exception=True, every_n_epochs=5
896+         )  # every_n_epochs (5) exceeds max_epochs (4); presumably so only the on-exception save writes a file - confirm
897+         trainer = Trainer(
898+             default_root_dir=tmp_path,
899+             callbacks=[checkpoint_callback, troubled_callback],
900+             max_epochs=4,
901+             limit_train_batches=2,
902+             logger=False,
903+             enable_progress_bar=False,
904+         )
905+         with pytest.raises(RuntimeError, match="Trouble!"):
906+             trainer.fit(BoringModel())  # fresh model for every run
907+         ckpt_path = tmp_path / f"exception-{callback_name}.ckpt"  # NOTE(review): assumes ModelCheckpoint's "exception-" filename prefix - confirm
908+         assert os.path.isfile(ckpt_path), f"no on-exception checkpoint written for {callback_name}"
909+         checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
910+         assert checkpoint["state_dict"], f"missing or empty state_dict for {callback_name}"
1006911
1007912
1008913@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" ) 
0 commit comments