@@ -837,26 +837,20 @@ def on_train_batch_start(self, batch, batch_idx) -> None:
837837 assert os .path .isfile (tmp_path / "already_saved.ckpt" )
838838
839839
840- class TroubledModelOnTrainEpochStart (BoringModel ):
841- def on_train_epoch_start (self ):
842- if self .current_epoch == 1 :
843- raise RuntimeError ("Trouble!" )
844-
845-
846- class TroubledModelOnTrainBatchStart (BoringModel ):
847- def on_train_batch_start (self , batch , batch_idx ):
840+ class TroubledModelInTrainingStep (BoringModel ):
841+ def training_step (self , batch , batch_idx ):
848842 if batch_idx == 1 :
849843 raise RuntimeError ("Trouble!" )
850844
851845
852- class TroubledModelInTrainingStep (BoringModel ):
853- def training_step (self , batch , batch_idx ):
854- if batch_idx == 1 :
846+ class TroubledModelInValidationStep (BoringModel ):
847+ def validation_step (self , batch , batch_idx ):
848+ if not self . trainer . sanity_checking and batch_idx == 1 :
855849 raise RuntimeError ("Trouble!" )
856850
857851
858- class TroubledModelOnBeforeZeroGrad (BoringModel ):
859- def on_before_zero_grad (self , optimizer ):
852+ class TroubledModelBackward (BoringModel ):
853+ def backward (self , loss ):
860854 if self .current_epoch == 1 :
861855 raise RuntimeError ("Trouble!" )
862856
@@ -873,22 +867,15 @@ def on_after_backward(self):
873867 raise RuntimeError ("Trouble!" )
874868
875869
876- class TroubledModelOnBeforeOptimizerStep (BoringModel ):
877- def on_before_optimizer_step (self , optimizer ):
870+ class TroubledModelOnBeforeZeroGrad (BoringModel ):
871+ def on_before_zero_grad (self , optimizer ):
878872 if self .current_epoch == 1 :
879873 raise RuntimeError ("Trouble!" )
880874
881875
882- class TroubledModelOnTrainBatchEnd (BoringModel ):
883- def on_train_batch_end (self , outputs , batch , batch_idx ):
884- if batch_idx == 1 :
885- raise RuntimeError ("Trouble!" )
886-
887-
888- class TroubledModelOnTrainEpochEnd (BoringModel ):
889- def on_train_epoch_end (self ):
890- if self .current_epoch == 1 :
891- raise RuntimeError ("Trouble!" )
876+ class TroubledModelOnFitEnd (BoringModel ):
877+ def on_fit_end (self ):
878+ raise RuntimeError ("Trouble!" )
892879
893880
894881class TroubledModelOnTrainEnd (BoringModel ):
@@ -902,20 +889,38 @@ def on_validation_start(self):
902889 raise RuntimeError ("Trouble!" )
903890
904891
905- class TroubledModelOnValidationEpochStart (BoringModel ):
906- def on_validation_epoch_start (self ):
907- if not self .trainer .sanity_checking and self . current_epoch == 1 :
892+ class TroubledModelOnValidationEnd (BoringModel ):
893+ def on_validation_end (self ):
894+ if not self .trainer .sanity_checking :
908895 raise RuntimeError ("Trouble!" )
909896
910897
911- class TroubledModelOnValidationBatchStart (BoringModel ):
912- def on_validation_batch_start (self , batch , batch_idx ):
913- if not self . trainer . sanity_checking and batch_idx == 1 :
898+ class TroubledModelOnTrainBatchStart (BoringModel ):
899+ def on_train_batch_start (self , batch , batch_idx ):
900+ if batch_idx == 1 :
914901 raise RuntimeError ("Trouble!" )
915902
916903
917- class TroubledModelInValidationStep (BoringModel ):
918- def validation_step (self , batch , batch_idx ):
904+ class TroubledModelOnTrainBatchEnd (BoringModel ):
905+ def on_train_batch_end (self , outputs , batch , batch_idx ):
906+ if batch_idx == 1 :
907+ raise RuntimeError ("Trouble!" )
908+
909+
910+ class TroubledModelOnTrainEpochStart (BoringModel ):
911+ def on_train_epoch_start (self ):
912+ if self .current_epoch == 1 :
913+ raise RuntimeError ("Trouble!" )
914+
915+
916+ class TroubledModelOnTrainEpochEnd (BoringModel ):
917+ def on_train_epoch_end (self ):
918+ if self .current_epoch == 1 :
919+ raise RuntimeError ("Trouble!" )
920+
921+
922+ class TroubledModelOnValidationBatchStart (BoringModel ):
923+ def on_validation_batch_start (self , batch , batch_idx ):
919924 if not self .trainer .sanity_checking and batch_idx == 1 :
920925 raise RuntimeError ("Trouble!" )
921926
@@ -926,44 +931,82 @@ def on_validation_batch_end(self, outputs, batch, batch_idx):
926931 raise RuntimeError ("Trouble!" )
927932
928933
934+ class TroubledModelOnValidationEpochStart (BoringModel ):
935+ def on_validation_epoch_start (self ):
936+ if not self .trainer .sanity_checking and self .current_epoch == 1 :
937+ raise RuntimeError ("Trouble!" )
938+
939+
929940class TroubledModelOnValidationEpochEnd (BoringModel ):
930941 def on_validation_epoch_end (self ):
931942 if not self .trainer .sanity_checking and self .current_epoch == 1 :
932943 raise RuntimeError ("Trouble!" )
933944
934945
935- class TroubledModelOnValidationEnd (BoringModel ):
936- def on_validation_end (self ):
937- if not self .trainer .sanity_checking :
946+ class TroubledModelOnValidationModelEval (BoringModel ):
947+ def on_validation_model_eval (self ):
948+ if not self .trainer .sanity_checking and self . current_epoch == 1 :
938949 raise RuntimeError ("Trouble!" )
939950
940951
941- class TroubledModelOnFitEnd (BoringModel ):
942- def on_fit_end (self ):
943- raise RuntimeError ("Trouble!" )
952+ class TroubledModelOnValidationModelTrain (BoringModel ):
953+ def on_validation_model_train (self ):
954+ if not self .trainer .sanity_checking and self .current_epoch == 1 :
955+ raise RuntimeError ("Trouble!" )
956+
957+
958+ class TroubledModelOnBeforeOptimizerStep (BoringModel ):
959+ def on_before_optimizer_step (self , optimizer ):
960+ if self .current_epoch == 1 :
961+ raise RuntimeError ("Trouble!" )
962+
963+
964+ class TroubledModelConfigureGradienClipping (BoringModel ):
965+ def configure_gradient_clipping (self , optimizer , gradient_clip_val = None , gradient_clip_algorithm = None ):
966+ if self .current_epoch == 1 :
967+ raise RuntimeError ("Trouble!" )
968+
969+
970+ class TroubledModelOptimizerStep (BoringModel ):
971+ def optimizer_step (self , epoch , batch_idx , optimizer , optimizer_closure = None ):
972+ optimizer .step (closure = optimizer_closure )
973+ if self .current_epoch == 1 :
974+ raise RuntimeError ("Trouble!" )
975+
976+
977+ class TroubledModelOptimizerZeroGrad (BoringModel ):
978+ def optimizer_zero_grad (self , epoch , batch_idx , optimizer ):
979+ if self .current_epoch == 1 :
980+ raise RuntimeError ("Trouble!" )
944981
945982
946983@pytest .mark .parametrize (
947984 "TroubledModel" ,
948985 [
949- TroubledModelOnTrainEpochStart ,
950- TroubledModelOnTrainBatchStart ,
951986 TroubledModelInTrainingStep ,
952- TroubledModelOnBeforeZeroGrad ,
987+ TroubledModelInValidationStep ,
988+ TroubledModelBackward ,
953989 TroubledModelOnBeforeBackward ,
954990 TroubledModelOnAfterBackward ,
955- TroubledModelOnBeforeOptimizerStep ,
956- TroubledModelOnTrainBatchEnd ,
957- TroubledModelOnTrainEpochEnd ,
991+ TroubledModelOnBeforeZeroGrad ,
992+ TroubledModelOnFitEnd ,
958993 TroubledModelOnTrainEnd ,
959994 TroubledModelOnValidationStart ,
960- TroubledModelOnValidationEpochStart ,
995+ TroubledModelOnValidationEnd ,
996+ TroubledModelOnTrainBatchStart ,
997+ TroubledModelOnTrainBatchEnd ,
998+ TroubledModelOnTrainEpochStart ,
999+ TroubledModelOnTrainEpochEnd ,
9611000 TroubledModelOnValidationBatchStart ,
962- TroubledModelInValidationStep ,
9631001 TroubledModelOnValidationBatchEnd ,
1002+ TroubledModelOnValidationEpochStart ,
9641003 TroubledModelOnValidationEpochEnd ,
965- TroubledModelOnValidationEnd ,
966- TroubledModelOnFitEnd ,
1004+ TroubledModelOnValidationModelEval ,
1005+ TroubledModelOnValidationModelTrain ,
1006+ TroubledModelOnBeforeOptimizerStep ,
1007+ TroubledModelConfigureGradienClipping ,
1008+ TroubledModelOptimizerStep ,
1009+ TroubledModelOptimizerZeroGrad ,
9671010 ],
9681011)
9691012def test_model_checkpoint_on_exception_parametrized (tmp_path , TroubledModel ):
0 commit comments