@@ -773,7 +773,8 @@ def training_step(self, batch, batch_idx):
773
773
774
774
model = TroubledModel ()
775
775
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
776
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ], max_epochs = 5 , logger = False , enable_progress_bar = False )
776
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ],
777
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
777
778
with pytest .raises (RuntimeError , match = "Trouble!" ):
778
779
trainer .fit (model )
779
780
print (os .listdir (tmp_path ))
@@ -785,15 +786,16 @@ class TroubledModel(BoringModel):
785
786
def validation_step (self , batch , batch_idx ):
786
787
if not trainer .sanity_checking and batch_idx == 0 :
787
788
raise RuntimeError ("Trouble!" )
788
-
789
+
789
790
model = TroubledModel ()
790
791
epoch_length = 64
791
792
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
792
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ], max_epochs = 5 , logger = False , enable_progress_bar = False )
793
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ],
794
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
793
795
with pytest .raises (RuntimeError , match = "Trouble!" ):
794
796
trainer .fit (model )
795
797
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
796
-
798
+
797
799
798
800
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start (tmp_path ):
799
801
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
@@ -804,7 +806,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
804
806
805
807
model = BoringModel ()
806
808
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
807
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
809
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchStart ()],
810
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
808
811
with pytest .raises (RuntimeError , match = "Trouble!" ):
809
812
trainer .fit (model )
810
813
assert os .path .isfile (tmp_path / "step=1.ckpt" )
@@ -816,10 +819,11 @@ class TroublemakerOnTrainBatchEnd(Callback):
816
819
def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
817
820
if batch_idx == 1 :
818
821
raise RuntimeError ("Trouble!" )
819
-
822
+
820
823
model = BoringModel ()
821
824
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
822
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
825
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()],
826
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
823
827
with pytest .raises (RuntimeError , match = "Trouble!" ):
824
828
trainer .fit (model )
825
829
@@ -832,11 +836,12 @@ class TroublemakerOnTrainEpochStart(Callback):
832
836
def on_train_epoch_start (self , trainer , pl_module ):
833
837
if trainer .current_epoch == 1 :
834
838
raise RuntimeError ("Trouble!" )
835
-
839
+
836
840
model = BoringModel ()
837
841
epoch_length = 64
838
842
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
839
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
843
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
844
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
840
845
with pytest .raises (RuntimeError , match = "Trouble!" ):
841
846
trainer .fit (model )
842
847
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -852,7 +857,8 @@ def on_train_epoch_end(self, trainer, pl_module):
852
857
model = BoringModel ()
853
858
epoch_length = 64
854
859
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
855
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
860
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
861
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
856
862
with pytest .raises (RuntimeError , match = "Trouble!" ):
857
863
trainer .fit (model )
858
864
assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
@@ -868,7 +874,8 @@ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
868
874
model = BoringModel ()
869
875
epoch_length = 64
870
876
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
871
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
877
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchStart ()],
878
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
872
879
with pytest .raises (RuntimeError , match = "Trouble!" ):
873
880
trainer .fit (model )
874
881
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -880,11 +887,12 @@ class TroublemakerOnValidationBatchEnd(Callback):
880
887
def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
881
888
if not trainer .sanity_checking and batch_idx == 1 :
882
889
raise RuntimeError ("Trouble!" )
883
-
890
+
884
891
model = BoringModel ()
885
892
epoch_length = 64
886
893
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
887
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
894
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationBatchEnd ()],
895
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
888
896
with pytest .raises (RuntimeError , match = "Trouble!" ):
889
897
trainer .fit (model )
890
898
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -900,23 +908,25 @@ def on_validation_epoch_start(self, trainer, pl_module):
900
908
model = BoringModel ()
901
909
epoch_length = 64
902
910
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
903
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
911
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
912
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
904
913
with pytest .raises (RuntimeError , match = "Trouble!" ):
905
914
trainer .fit (model )
906
915
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
907
-
916
+
908
917
909
918
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end (tmp_path ):
910
919
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
911
920
class TroublemakerOnValidationEpochEnd (Callback ):
912
921
def on_validation_epoch_end (self , trainer , pl_module ):
913
922
if not trainer .sanity_checking and trainer .current_epoch == 0 :
914
923
raise RuntimeError ("Trouble!" )
915
-
924
+
916
925
model = BoringModel ()
917
926
epoch_length = 64
918
927
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
919
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
928
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
929
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
920
930
with pytest .raises (RuntimeError , match = "Trouble!" ):
921
931
trainer .fit (model )
922
932
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -927,11 +937,12 @@ class TroublemakerOnValidationStart(Callback):
927
937
def on_validation_start (self , trainer , pl_module ):
928
938
if not trainer .sanity_checking :
929
939
raise RuntimeError ("Trouble!" )
930
-
940
+
931
941
model = BoringModel ()
932
942
epoch_length = 64
933
943
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
934
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
944
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
945
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
935
946
with pytest .raises (RuntimeError , match = "Trouble!" ):
936
947
trainer .fit (model )
937
948
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
@@ -942,16 +953,17 @@ class TroublemakerOnValidationEnd(Callback):
942
953
def on_validation_end (self , trainer , pl_module ):
943
954
if not trainer .sanity_checking :
944
955
raise RuntimeError ("Trouble!" )
945
-
956
+
946
957
model = BoringModel ()
947
958
epoch_length = 64
948
959
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
949
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()], max_epochs = 5 , logger = False , enable_progress_bar = False )
960
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
961
+ max_epochs = 5 , logger = False , enable_progress_bar = False )
950
962
with pytest .raises (RuntimeError , match = "Trouble!" ):
951
963
trainer .fit (model )
952
964
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
953
965
954
-
966
+
955
967
@mock .patch ("lightning.pytorch.callbacks.model_checkpoint.time" )
956
968
def test_model_checkpoint_train_time_interval (mock_datetime , tmp_path ) -> None :
957
969
"""Tests that the checkpoints are saved at the specified time interval."""
0 commit comments