@@ -773,7 +773,8 @@ def training_step(self, batch, batch_idx):
773773
774774 model = TroubledModel ()
775775 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
776- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ], max_epochs = 5 , logger = False , enable_progress_bar = False )
776+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ],
777+ max_epochs = 5 , logger = False , enable_progress_bar = False )
777778 with pytest .raises (RuntimeError , match = "Trouble!" ):
778779 trainer .fit (model )
779780 print (os .listdir (tmp_path ))
@@ -785,15 +786,16 @@ class TroubledModel(BoringModel):
785786 def validation_step (self , batch , batch_idx ):
786787 if not trainer .sanity_checking and batch_idx == 0 :
787788 raise RuntimeError ("Trouble!" )
788-
789+
789790 model = TroubledModel ()
790791 epoch_length = 64
791792 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
792- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ], max_epochs = 5 , logger = False , enable_progress_bar = False )
793+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ],
794+ max_epochs = 5 , logger = False , enable_progress_bar = False )
793795 with pytest .raises (RuntimeError , match = "Trouble!" ):
794796 trainer .fit (model )
795797 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
796-
798+
797799
798800def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start (tmp_path ):
799801 """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
@@ -804,7 +806,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
804806
805807 model = BoringModel ()
806808 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
807- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
809+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchStart ()],
810+ max_epochs = 5 , logger = False , enable_progress_bar = False )
808811 with pytest .raises (RuntimeError , match = "Trouble!" ):
809812 trainer .fit (model )
810813 assert os .path .isfile (tmp_path / "step=1.ckpt" )
@@ -816,10 +819,11 @@ class TroublemakerOnTrainBatchEnd(Callback):
816819 def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
817820 if batch_idx == 1 :
818821 raise RuntimeError ("Trouble!" )
819-
822+
820823 model = BoringModel ()
821824 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
822- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
825+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()],
826+ max_epochs = 5 , logger = False , enable_progress_bar = False )
823827 with pytest .raises (RuntimeError , match = "Trouble!" ):
824828 trainer .fit (model )
825829
@@ -832,11 +836,12 @@ class TroublemakerOnTrainEpochStart(Callback):
832836 def on_train_epoch_start (self , trainer , pl_module ):
833837 if trainer .current_epoch == 1 :
834838 raise RuntimeError ("Trouble!" )
835-
839+
836840 model = BoringModel ()
837841 epoch_length = 64
838842 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
839- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
843+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
844+ max_epochs = 5 , logger = False , enable_progress_bar = False )
840845 with pytest .raises (RuntimeError , match = "Trouble!" ):
841846 trainer .fit (model )
842847 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -852,7 +857,8 @@ def on_train_epoch_end(self, trainer, pl_module):
852857 model = BoringModel ()
853858 epoch_length = 64
854859 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
855- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
860+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
861+ max_epochs = 5 , logger = False , enable_progress_bar = False )
856862 with pytest .raises (RuntimeError , match = "Trouble!" ):
857863 trainer .fit (model )
858864 assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
@@ -868,7 +874,8 @@ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
868874 model = BoringModel ()
869875 epoch_length = 64
870876 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
871- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
877+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchStart ()],
878+ max_epochs = 5 , logger = False , enable_progress_bar = False )
872879 with pytest .raises (RuntimeError , match = "Trouble!" ):
873880 trainer .fit (model )
874881 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -880,11 +887,12 @@ class TroublemakerOnValidationBatchEnd(Callback):
880887 def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
881888 if not trainer .sanity_checking and batch_idx == 1 :
882889 raise RuntimeError ("Trouble!" )
883-
890+
884891 model = BoringModel ()
885892 epoch_length = 64
886893 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
887- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
894+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchEnd ()],
895+ max_epochs = 5 , logger = False , enable_progress_bar = False )
888896 with pytest .raises (RuntimeError , match = "Trouble!" ):
889897 trainer .fit (model )
890898 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -900,23 +908,25 @@ def on_validation_epoch_start(self, trainer, pl_module):
900908 model = BoringModel ()
901909 epoch_length = 64
902910 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
903- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
911+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
912+ max_epochs = 5 , logger = False , enable_progress_bar = False )
904913 with pytest .raises (RuntimeError , match = "Trouble!" ):
905914 trainer .fit (model )
906915 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
907-
916+
908917
909918def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end (tmp_path ):
910919 """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
911920 class TroublemakerOnValidationEpochEnd (Callback ):
912921 def on_validation_epoch_end (self , trainer , pl_module ):
913922 if not trainer .sanity_checking and trainer .current_epoch == 0 :
914923 raise RuntimeError ("Trouble!" )
915-
924+
916925 model = BoringModel ()
917926 epoch_length = 64
918927 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
919- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
928+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
929+ max_epochs = 5 , logger = False , enable_progress_bar = False )
920930 with pytest .raises (RuntimeError , match = "Trouble!" ):
921931 trainer .fit (model )
922932 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -927,11 +937,12 @@ class TroublemakerOnValidationStart(Callback):
927937 def on_validation_start (self , trainer , pl_module ):
928938 if not trainer .sanity_checking :
929939 raise RuntimeError ("Trouble!" )
930-
940+
931941 model = BoringModel ()
932942 epoch_length = 64
933943 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
934- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
944+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
945+ max_epochs = 5 , logger = False , enable_progress_bar = False )
935946 with pytest .raises (RuntimeError , match = "Trouble!" ):
936947 trainer .fit (model )
937948 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -942,16 +953,17 @@ class TroublemakerOnValidationEnd(Callback):
942953 def on_validation_end (self , trainer , pl_module ):
943954 if not trainer .sanity_checking :
944955 raise RuntimeError ("Trouble!" )
945-
956+
946957 model = BoringModel ()
947958 epoch_length = 64
948959 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
949- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
960+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
961+ max_epochs = 5 , logger = False , enable_progress_bar = False )
950962 with pytest .raises (RuntimeError , match = "Trouble!" ):
951963 trainer .fit (model )
952964 assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
953965
954-
966+
955967@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" )
956968def test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) -> None :
957969 """Tests that the checkpoints are saved at the specified time interval."""
0 commit comments