@@ -764,52 +764,126 @@ def test_ckpt_every_n_train_steps(tmp_path):
764
764
assert set (os .listdir (tmp_path )) == set (expected )
765
765
766
766
767
- def test_model_checkpoint_save_on_exception_in_train_callback (tmp_path ):
768
- """Test that the checkpoint is saved when an exception is raised in a callback on different events ."""
767
+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start (tmp_path ):
768
+ """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start ."""
769
769
class TroublemakerOnTrainBatchStart (Callback ):
770
770
def on_train_batch_start (self , trainer , pl_module , batch , batch_idx ):
771
771
if batch_idx == 1 :
772
772
raise RuntimeError ("Trouble!" )
773
773
774
+ model = BoringModel ()
775
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
776
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchStart ()], max_epochs = 5 , logger = False )
777
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
778
+ trainer .fit (model )
779
+ assert os .path .isfile (tmp_path / "step=1.ckpt" )
780
+
781
+
782
+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end (tmp_path ):
783
+ """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end."""
774
784
class TroublemakerOnTrainBatchEnd (Callback ):
775
785
def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
776
786
if batch_idx == 1 :
777
787
raise RuntimeError ("Trouble!" )
788
+
789
+ model = BoringModel ()
790
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
791
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()], max_epochs = 5 , logger = False )
792
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
793
+ trainer .fit (model )
778
794
795
+ assert os .path .isfile (tmp_path / "step=2.ckpt" )
796
+
797
+
798
+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start (tmp_path ):
799
+ """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start."""
779
800
class TroublemakerOnTrainEpochStart (Callback ):
780
801
def on_train_epoch_start (self , trainer , pl_module ):
781
802
if trainer .current_epoch == 1 :
782
803
raise RuntimeError ("Trouble!" )
804
+
805
+ model = BoringModel ()
806
+ epoch_length = 64
807
+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
808
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()], max_epochs = 5 , logger = False )
809
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
810
+ trainer .fit (model )
811
+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
812
+
783
813
814
+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end (tmp_path ):
815
+ """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end."""
784
816
class TroublemakerOnTrainEpochEnd (Callback ):
785
817
def on_train_epoch_end (self , trainer , pl_module ):
786
818
if trainer .current_epoch == 1 :
787
819
raise RuntimeError ("Trouble!" )
788
820
789
-
790
- epoch_length = 64
791
821
model = BoringModel ()
792
- # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
822
+ epoch_length = 64
793
823
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
824
+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()], max_epochs = 5 , logger = False )
825
+ with pytest .raises (RuntimeError , match = "Trouble!" ):
826
+ trainer .fit (model )
827
+ assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
828
+
829
+
830
+ # def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
831
+ # """Test that the checkpoint is saved when an exception is raised in a callback on different events."""
832
+ # class TroublemakerOnTrainBatchStart(Callback):
833
+ # def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
834
+ # if batch_idx == 1:
835
+ # raise RuntimeError("Trouble!")
836
+
837
+ # class TroublemakerOnTrainBatchEnd(Callback):
838
+ # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
839
+ # if batch_idx == 1:
840
+ # raise RuntimeError("Trouble!")
841
+
842
+ # class TroublemakerOnTrainEpochStart(Callback):
843
+ # def on_train_epoch_start(self, trainer, pl_module):
844
+ # if trainer.current_epoch == 1:
845
+ # raise RuntimeError("Trouble!")
846
+
847
+ # class TroublemakerOnTrainEpochEnd(Callback):
848
+ # def on_train_epoch_end(self, trainer, pl_module):
849
+ # if trainer.current_epoch == 1:
850
+ # raise RuntimeError("Trouble!")
851
+
852
+
853
+ # epoch_length = 64
854
+ # model = BoringModel()
855
+ # # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
856
+ # checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
857
+
858
+ # troublemakers = [
859
+ # TroublemakerOnTrainBatchStart(),
860
+ # TroublemakerOnTrainBatchEnd(),
861
+ # TroublemakerOnTrainEpochStart(),
862
+ # TroublemakerOnTrainEpochEnd()
863
+ # ]
864
+
865
+ # expected_ckpts = ["step=1.ckpt",
866
+ # 'step=2.ckpt',
867
+ # f'step={epoch_length}.ckpt',
868
+ # f'step={2*epoch_length}.ckpt',
869
+ # ]
870
+
871
+ # for troublemaker in troublemakers:
872
+ # trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
873
+
874
+ # with pytest.raises(RuntimeError, match="Trouble!"):
875
+ # trainer.fit(model)
794
876
795
- troublemakers = [
796
- TroublemakerOnTrainBatchStart (),
797
- TroublemakerOnTrainBatchEnd (),
798
- TroublemakerOnTrainEpochStart (),
799
- TroublemakerOnTrainEpochEnd ()
800
- ]
877
+ # assert set(os.listdir(tmp_path)) == set(expected_ckpts)
801
878
802
- expected_ckpts = ["step=1.ckpt" ,
803
- 'step=2.ckpt' ,
804
- f'step={ epoch_length } .ckpt' ,
805
- f'step={ 2 * epoch_length } .ckpt' ,
806
- ]
807
879
808
- for troublemaker in troublemakers :
809
- trainer = Trainer ( default_root_dir = tmp_path , callbacks = [ checkpoint_callback , troublemaker ], max_epochs = 5 , logger = False )
880
+ # # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
881
+ # checkpoint_callback = ModelCheckpoint(dirpath =tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4 )
810
882
811
- with pytest .raises (RuntimeError , match = "Trouble!" ):
812
- trainer .fit (model )
883
+ # troublemakers = [
884
+ # # TroublemakerOnValidationBatchStart(),
885
+ # TroublemakerOnValidationBatchEnd(),
886
+ # expected_ckpts = [f"step={2*epoch_length}.ckpt",
813
887
814
888
assert set (os .listdir (tmp_path )) == set (expected_ckpts )
815
889
0 commit comments