@@ -994,6 +994,11 @@ def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel):
994994 assert checkpoint ["state_dict" ] != {}
995995
996996
997+ class TroubledCallbackOnFitEnd (Callback ):
998+ def on_fit_end (self , trainer , pl_module ):
999+ raise RuntimeError ("Trouble!" )
1000+
1001+
9971002class TroubledCallbackOnTrainBatchStart (Callback ):
9981003 def on_train_batch_start (self , trainer , pl_module , batch , batch_idx ):
9991004 if batch_idx == 1 :
@@ -1018,9 +1023,16 @@ def on_train_epoch_end(self, trainer, pl_module):
10181023 raise RuntimeError ("Trouble!" )
10191024
10201025
1021- class TroubledCallbackOnTrainEnd (Callback ):
1022- def on_train_end (self , trainer , pl_module ):
1023- raise RuntimeError ("Trouble!" )
1026+ class TroubledCallbackOnValidationEpochStart (Callback ):
1027+ def on_validation_epoch_start (self , trainer , pl_module ):
1028+ if not trainer .sanity_checking and trainer .current_epoch == 1 :
1029+ raise RuntimeError ("Trouble!" )
1030+
1031+
1032+ class TroubledCallbackOnValidationEpochEnd (Callback ):
1033+ def on_validation_epoch_end (self , trainer , pl_module ):
1034+ if not trainer .sanity_checking and trainer .current_epoch == 1 :
1035+ raise RuntimeError ("Trouble!" )
10241036
10251037
10261038class TroubledCallbackOnValidationBatchStart (Callback ):
@@ -1035,16 +1047,9 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)
10351047 raise RuntimeError ("Trouble!" )
10361048
10371049
1038- class TroubledCallbackOnValidationEpochStart (Callback ):
1039- def on_validation_epoch_start (self , trainer , pl_module ):
1040- if not trainer .sanity_checking and trainer .current_epoch == 1 :
1041- raise RuntimeError ("Trouble!" )
1042-
1043-
1044- class TroubledCallbackOnValidationEpochEnd (Callback ):
1045- def on_validation_epoch_end (self , trainer , pl_module ):
1046- if not trainer .sanity_checking and trainer .current_epoch == 1 :
1047- raise RuntimeError ("Trouble!" )
1050+ class TroubledCallbackOnTrainEnd (Callback ):
1051+ def on_train_end (self , trainer , pl_module ):
1052+ raise RuntimeError ("Trouble!" )
10481053
10491054
10501055class TroubledCallbackOnValidationStart (Callback ):
@@ -1059,26 +1064,52 @@ def on_validation_end(self, trainer, pl_module):
10591064 raise RuntimeError ("Trouble!" )
10601065
10611066
1062- class TroubledCallbackOnFitEnd (Callback ):
1063- def on_fit_end (self , trainer , pl_module ):
1064- raise RuntimeError ("Trouble!" )
1067+ class TroubleCallbackOnBeforeBackward (Callback ):
1068+ def on_before_backward (self , trainer , pl_module , loss ):
1069+ if trainer .current_epoch == 1 :
1070+ raise RuntimeError ("Trouble!" )
1071+
1072+
1073+ class TroubleCallbackOnAfterBackward (Callback ):
1074+ def on_after_backward (self , trainer , pl_module ):
1075+ if trainer .current_epoch == 1 :
1076+ raise RuntimeError ("Trouble!" )
1077+
1078+
1079+ class TroubleCallbackOnBeforeOptimizerStep (Callback ):
1080+ def on_before_optimizer_step (self , trainer , pl_module , optimizer ):
1081+ if trainer .current_epoch == 1 :
1082+ raise RuntimeError ("Trouble!" )
1083+
1084+
1085+ class TroubleCallbackOnBeforeZeroGrad (Callback ):
1086+ def on_before_zero_grad (self , trainer , pl_module , optimizer ):
1087+ if trainer .current_epoch == 1 :
1088+ raise RuntimeError ("Trouble!" )
1089+
1090+
1091+ ####
10651092
10661093
10671094@pytest .mark .parametrize (
10681095 "TroubledCallback" ,
10691096 [
1097+ TroubledCallbackOnFitEnd ,
10701098 TroubledCallbackOnTrainBatchStart ,
10711099 TroubledCallbackOnTrainBatchEnd ,
10721100 TroubledCallbackOnTrainEpochStart ,
10731101 TroubledCallbackOnTrainEpochEnd ,
1074- TroubledCallbackOnTrainEnd ,
1075- TroubledCallbackOnValidationBatchStart ,
1076- TroubledCallbackOnValidationBatchEnd ,
10771102 TroubledCallbackOnValidationEpochStart ,
10781103 TroubledCallbackOnValidationEpochEnd ,
1104+ TroubledCallbackOnValidationBatchStart ,
1105+ TroubledCallbackOnValidationBatchEnd ,
1106+ TroubledCallbackOnTrainEnd ,
10791107 TroubledCallbackOnValidationStart ,
10801108 TroubledCallbackOnValidationEnd ,
1081- TroubledCallbackOnFitEnd ,
1109+ TroubleCallbackOnBeforeBackward ,
1110+ TroubleCallbackOnAfterBackward ,
1111+ TroubleCallbackOnBeforeOptimizerStep ,
1112+ TroubleCallbackOnBeforeZeroGrad ,
10821113 ],
10831114)
10841115def test_model_checkpoint_on_exception_in_other_callbacks (tmp_path , TroubledCallback ):
0 commit comments