@@ -776,27 +776,124 @@ def test_ckpt_every_n_train_steps(tmp_path):
776776def test_model_checkpoint_on_exception (tmp_path ):
777777 """Test that the checkpoint is saved when an exception is raised in a lightning module."""
778778
779+ class TroubledModelOnTrainEpochStart (BoringModel ):
780+ def on_train_epoch_start (self ):
781+ if self .current_epoch == 1 :
782+ raise RuntimeError ("Trouble!" )
783+
784+ class TroubledModelOnTrainBatchStart (BoringModel ):
785+ def on_train_batch_start (self , batch , batch_idx ):
786+ if batch_idx == 1 :
787+ raise RuntimeError ("Trouble!" )
788+
779789 class TroubledModelInTrainingStep (BoringModel ):
780790 def training_step (self , batch , batch_idx ):
781791 if batch_idx == 1 :
782792 raise RuntimeError ("Trouble!" )
783793
794+ class TroubledModelOnBeforeZeroGrad (BoringModel ):
795+ def on_before_zero_grad (self , optimizer ):
796+ if self .current_epoch == 1 :
797+ raise RuntimeError ("Trouble!" )
798+
799+ class TroubledModelOnBeforeBackward (BoringModel ):
800+ def on_before_backward (self , loss ):
801+ if self .current_epoch == 1 :
802+ raise RuntimeError ("Trouble!" )
803+
804+ class TroubledModelOnAfterBackward (BoringModel ):
805+ def on_after_backward (self ):
806+ if self .current_epoch == 1 :
807+ raise RuntimeError ("Trouble!" )
808+
809+ class TroubledModelOnBeforeOptimizerStep (BoringModel ):
810+ def on_before_optimizer_step (self , optimizer ):
811+ if self .current_epoch == 1 :
812+ raise RuntimeError ("Trouble!" )
813+
814+ class TroubledModelOnTrainBatchEnd (BoringModel ):
815+ def on_train_batch_end (self , outputs , batch , batch_idx ):
816+ if batch_idx == 1 :
817+ raise RuntimeError ("Trouble!" )
818+
819+ class TroubledModelOnTrainEpochEnd (BoringModel ):
820+ def on_train_epoch_end (self ):
821+ if self .current_epoch == 1 :
822+ raise RuntimeError ("Trouble!" )
823+
824+ class TroubledModelOnTrainEnd (BoringModel ):
825+ def on_train_end (self ):
826+ raise RuntimeError ("Trouble!" )
827+
828+ class TroubledModelOnValidationStart (BoringModel ):
829+ def on_validation_start (self ):
830+ if not self .trainer .sanity_checking and self .current_epoch == 1 :
831+ raise RuntimeError ("Trouble!" )
832+
833+ class TroubledModelOnValidationEpochStart (BoringModel ):
834+ def on_validation_epoch_start (self ):
835+ if not self .trainer .sanity_checking and self .current_epoch == 1 :
836+ raise RuntimeError ("Trouble!" )
837+
838+ class TroubledModelOnValidationBatchStart (BoringModel ):
839+ def on_validation_batch_start (self , batch , batch_idx ):
840+ if not self .trainer .sanity_checking and batch_idx == 1 :
841+ raise RuntimeError ("Trouble!" )
842+
784843 class TroubledModelInValidationStep (BoringModel ):
785844 def validation_step (self , batch , batch_idx ):
786- if not trainer .sanity_checking and batch_idx == 1 :
845+ if not self .trainer .sanity_checking and batch_idx == 1 :
846+ raise RuntimeError ("Trouble!" )
847+
848+ class TroubledModelOnValidationBatchEnd (BoringModel ):
849+ def on_validation_batch_end (self , outputs , batch , batch_idx ):
850+ if not self .trainer .sanity_checking and batch_idx == 1 :
851+ raise RuntimeError ("Trouble!" )
852+
853+ class TroubledModelOnValidationEpochEnd (BoringModel ):
854+ def on_validation_epoch_end (self ):
855+ if not self .trainer .sanity_checking and self .current_epoch == 1 :
856+ raise RuntimeError ("Trouble!" )
857+
858+ class TroubledModelOnValidationEnd (BoringModel ):
859+ def on_validation_end (self ):
860+ if not self .trainer .sanity_checking :
787861 raise RuntimeError ("Trouble!" )
788862
789- models = [TroubledModelInTrainingStep (), TroubledModelInValidationStep ()]
863+ class TroubledModelOnFitEnd (BoringModel ):
864+ def on_fit_end (self ):
865+ raise RuntimeError ("Trouble!" )
866+
867+ models = [
868+ TroubledModelOnTrainEpochStart (),
869+ TroubledModelOnTrainBatchStart (),
870+ TroubledModelInTrainingStep (),
871+ TroubledModelOnBeforeZeroGrad (),
872+ TroubledModelOnBeforeBackward (),
873+ TroubledModelOnAfterBackward (),
874+ TroubledModelOnBeforeOptimizerStep (),
875+ TroubledModelOnTrainBatchEnd (),
876+ TroubledModelOnTrainEpochEnd (),
877+ TroubledModelOnTrainEnd (),
878+ TroubledModelOnValidationStart (),
879+ TroubledModelOnValidationEpochStart (),
880+ TroubledModelOnValidationBatchStart (),
881+ TroubledModelInValidationStep (),
882+ TroubledModelOnValidationBatchEnd (),
883+ TroubledModelOnValidationEpochEnd (),
884+ TroubledModelOnValidationEnd (),
885+ TroubledModelOnFitEnd (),
886+ ]
790887
791888 for model in models :
792889 checkpoint_callback = ModelCheckpoint (
793- dirpath = tmp_path , filename = model .__class__ .__name__ , save_on_exception = True , every_n_epochs = 4
890+ dirpath = tmp_path , filename = model .__class__ .__name__ , save_on_exception = True , every_n_epochs = 5
794891 )
795892 trainer = Trainer (
796893 default_root_dir = tmp_path ,
797894 callbacks = [checkpoint_callback ],
798895 limit_train_batches = 2 ,
799- max_epochs = 5 ,
896+ max_epochs = 4 ,
800897 logger = False ,
801898 enable_progress_bar = False ,
802899 )
0 commit comments