@@ -766,22 +766,30 @@ def test_ckpt_every_n_train_steps(tmp_path):
766766
def test_model_checkpoint_save_on_exception_in_training_step(tmp_path):
    """Test that the checkpoint is saved when an exception is raised in training_step.

    A ``BoringModel`` subclass raises ``RuntimeError("Trouble!")`` from ``training_step`` at batch index 1.
    With ``save_on_exception=True`` set on the ``ModelCheckpoint`` callback, the test expects a file named
    ``step=1.ckpt`` to exist in ``tmp_path`` once ``trainer.fit`` has propagated the error.

    """

    class TroubledModel(BoringModel):
        def training_step(self, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    model = TroubledModel()
    # NOTE(review): every_n_epochs=4 presumably keeps the regular epoch-based save from firing before the
    # failure, so the asserted file can only come from the exception path -- confirm against ModelCheckpoint.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback],
        max_epochs=5,
        logger=False,
        enable_progress_bar=False,
    )
    # The error must still reach the caller; only afterwards is the directory inspected.
    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)
    assert os.path.isfile(tmp_path / "step=1.ckpt")
782788
789+
783790def test_model_checkpoint_save_on_exception_in_validation_step (tmp_path ):
784791 """Test that the checkpoint is saved when an exception is raised in validation_step."""
792+
785793 class TroubledModel (BoringModel ):
786794 def validation_step (self , batch , batch_idx ):
787795 if not trainer .sanity_checking and batch_idx == 0 :
@@ -790,40 +798,57 @@ def validation_step(self, batch, batch_idx):
790798 model = TroubledModel ()
791799 epoch_length = 64
792800 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
793- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ],
794- max_epochs = 5 , logger = False , enable_progress_bar = False )
801+ trainer = Trainer (
802+ default_root_dir = tmp_path ,
803+ callbacks = [checkpoint_callback ],
804+ max_epochs = 5 ,
805+ logger = False ,
806+ enable_progress_bar = False ,
807+ )
795808 with pytest .raises (RuntimeError , match = "Trouble!" ):
796809 trainer .fit (model )
797810 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
798811
799812
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path):
    """Check that a checkpoint is written when a callback raises in ``on_train_batch_start``.

    The failing callback raises ``RuntimeError("Trouble!")`` at batch index 1; the test then looks for
    ``step=1.ckpt`` inside ``tmp_path``.

    """

    class TroublemakerOnTrainBatchStart(Callback):
        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            # Every batch except index 1 passes through untouched.
            if batch_idx != 1:
                return
            raise RuntimeError("Trouble!")

    model = BoringModel()
    ckpt_cb = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    # The checkpoint callback stays first in the list, ahead of the failing callback.
    trainer_callbacks = [ckpt_cb, TroublemakerOnTrainBatchStart()]
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=trainer_callbacks,
        max_epochs=5,
        logger=False,
        enable_progress_bar=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / "step=1.ckpt"
    assert os.path.isfile(expected_ckpt)
814833
815834
816835def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end (tmp_path ):
817836 """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end."""
837+
818838 class TroublemakerOnTrainBatchEnd (Callback ):
819839 def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
820840 if batch_idx == 1 :
821841 raise RuntimeError ("Trouble!" )
822842
823843 model = BoringModel ()
824844 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
825- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()],
826- max_epochs = 5 , logger = False , enable_progress_bar = False )
845+ trainer = Trainer (
846+ default_root_dir = tmp_path ,
847+ callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()],
848+ max_epochs = 5 ,
849+ logger = False ,
850+ enable_progress_bar = False ,
851+ )
827852 with pytest .raises (RuntimeError , match = "Trouble!" ):
828853 trainer .fit (model )
829854
@@ -832,6 +857,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
832857
833858def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start (tmp_path ):
834859 """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start."""
860+
835861 class TroublemakerOnTrainEpochStart (Callback ):
836862 def on_train_epoch_start (self , trainer , pl_module ):
837863 if trainer .current_epoch == 1 :
@@ -840,15 +866,21 @@ def on_train_epoch_start(self, trainer, pl_module):
840866 model = BoringModel ()
841867 epoch_length = 64
842868 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
843- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
844- max_epochs = 5 , logger = False , enable_progress_bar = False )
869+ trainer = Trainer (
870+ default_root_dir = tmp_path ,
871+ callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
872+ max_epochs = 5 ,
873+ logger = False ,
874+ enable_progress_bar = False ,
875+ )
845876 with pytest .raises (RuntimeError , match = "Trouble!" ):
846877 trainer .fit (model )
847878 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
848879
849880
850881def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end (tmp_path ):
851882 """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end."""
883+
852884 class TroublemakerOnTrainEpochEnd (Callback ):
853885 def on_train_epoch_end (self , trainer , pl_module ):
854886 if trainer .current_epoch == 1 :
@@ -857,49 +889,67 @@ def on_train_epoch_end(self, trainer, pl_module):
857889 model = BoringModel ()
858890 epoch_length = 64
859891 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
860- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
861- max_epochs = 5 , logger = False , enable_progress_bar = False )
892+ trainer = Trainer (
893+ default_root_dir = tmp_path ,
894+ callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
895+ max_epochs = 5 ,
896+ logger = False ,
897+ enable_progress_bar = False ,
898+ )
862899 with pytest .raises (RuntimeError , match = "Trouble!" ):
863900 trainer .fit (model )
864- assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
901+ assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
865902
866903
def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path):
    """Check that a checkpoint is written when a callback raises in ``on_validation_batch_start``.

    The failing callback ignores the sanity-check pass and raises ``RuntimeError("Trouble!")`` at
    validation batch index 1; the test then looks for ``step=64.ckpt`` inside ``tmp_path``.

    """

    class TroublemakerOnValidationBatchStart(Callback):
        def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
            # Stay silent during sanity checking and for every batch other than index 1.
            if trainer.sanity_checking or batch_idx != 1:
                return
            raise RuntimeError("Trouble!")

    model = BoringModel()
    # NOTE(review): 64 is presumably the number of training batches per epoch for BoringModel -- confirm.
    epoch_length = 64
    ckpt_cb = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    # The checkpoint callback stays first in the list, ahead of the failing callback.
    trainer_callbacks = [ckpt_cb, TroublemakerOnValidationBatchStart()]
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=trainer_callbacks,
        max_epochs=5,
        logger=False,
        enable_progress_bar=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
882925
883926
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path):
    """Check that a checkpoint is written when a callback raises in ``on_validation_batch_end``.

    The failing callback ignores the sanity-check pass and raises ``RuntimeError("Trouble!")`` at
    validation batch index 1; the test then looks for ``step=64.ckpt`` inside ``tmp_path``.

    """

    class TroublemakerOnValidationBatchEnd(Callback):
        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # Stay silent during sanity checking and for every batch other than index 1.
            if trainer.sanity_checking or batch_idx != 1:
                return
            raise RuntimeError("Trouble!")

    model = BoringModel()
    # NOTE(review): 64 is presumably the number of training batches per epoch for BoringModel -- confirm.
    epoch_length = 64
    ckpt_cb = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    # The checkpoint callback stays first in the list, ahead of the failing callback.
    trainer_callbacks = [ckpt_cb, TroublemakerOnValidationBatchEnd()]
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=trainer_callbacks,
        max_epochs=5,
        logger=False,
        enable_progress_bar=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    expected_ckpt = tmp_path / f"step={epoch_length}.ckpt"
    assert os.path.isfile(expected_ckpt)
899948
900949
901950def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start (tmp_path ):
902951 """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start."""
952+
903953 class TroublemakerOnValidationEpochStart (Callback ):
904954 def on_validation_epoch_start (self , trainer , pl_module ):
905955 if not trainer .sanity_checking and trainer .current_epoch == 0 :
@@ -908,15 +958,21 @@ def on_validation_epoch_start(self, trainer, pl_module):
908958 model = BoringModel ()
909959 epoch_length = 64
910960 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
911- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
912- max_epochs = 5 , logger = False , enable_progress_bar = False )
961+ trainer = Trainer (
962+ default_root_dir = tmp_path ,
963+ callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
964+ max_epochs = 5 ,
965+ logger = False ,
966+ enable_progress_bar = False ,
967+ )
913968 with pytest .raises (RuntimeError , match = "Trouble!" ):
914969 trainer .fit (model )
915970 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
916971
917972
918973def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end (tmp_path ):
919974 """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
975+
920976 class TroublemakerOnValidationEpochEnd (Callback ):
921977 def on_validation_epoch_end (self , trainer , pl_module ):
922978 if not trainer .sanity_checking and trainer .current_epoch == 0 :
@@ -925,14 +981,21 @@ def on_validation_epoch_end(self, trainer, pl_module):
925981 model = BoringModel ()
926982 epoch_length = 64
927983 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
928- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
929- max_epochs = 5 , logger = False , enable_progress_bar = False )
984+ trainer = Trainer (
985+ default_root_dir = tmp_path ,
986+ callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
987+ max_epochs = 5 ,
988+ logger = False ,
989+ enable_progress_bar = False ,
990+ )
930991 with pytest .raises (RuntimeError , match = "Trouble!" ):
931992 trainer .fit (model )
932993 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
933994
995+
934996def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start (tmp_path ):
935997 """Test that the checkpoint is saved when an exception is raised in a callback on validation_start."""
998+
936999 class TroublemakerOnValidationStart (Callback ):
9371000 def on_validation_start (self , trainer , pl_module ):
9381001 if not trainer .sanity_checking :
@@ -941,14 +1004,21 @@ def on_validation_start(self, trainer, pl_module):
9411004 model = BoringModel ()
9421005 epoch_length = 64
9431006 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
944- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
945- max_epochs = 5 , logger = False , enable_progress_bar = False )
1007+ trainer = Trainer (
1008+ default_root_dir = tmp_path ,
1009+ callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
1010+ max_epochs = 5 ,
1011+ logger = False ,
1012+ enable_progress_bar = False ,
1013+ )
9461014 with pytest .raises (RuntimeError , match = "Trouble!" ):
9471015 trainer .fit (model )
9481016 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
9491017
1018+
9501019def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end (tmp_path ):
9511020 """Test that the checkpoint is saved when an exception is raised in a callback on validation_end."""
1021+
9521022 class TroublemakerOnValidationEnd (Callback ):
9531023 def on_validation_end (self , trainer , pl_module ):
9541024 if not trainer .sanity_checking :
@@ -957,8 +1027,13 @@ def on_validation_end(self, trainer, pl_module):
9571027 model = BoringModel ()
9581028 epoch_length = 64
9591029 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
960- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
961- max_epochs = 5 , logger = False , enable_progress_bar = False )
1030+ trainer = Trainer (
1031+ default_root_dir = tmp_path ,
1032+ callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
1033+ max_epochs = 5 ,
1034+ logger = False ,
1035+ enable_progress_bar = False ,
1036+ )
9621037 with pytest .raises (RuntimeError , match = "Trouble!" ):
9631038 trainer .fit (model )
9641039 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
0 commit comments