@@ -764,52 +764,126 @@ def test_ckpt_every_n_train_steps(tmp_path):
764764 assert set (os .listdir (tmp_path )) == set (expected )
765765
766766
767- def test_model_checkpoint_save_on_exception_in_train_callback (tmp_path ):
768- """Test that the checkpoint is saved when an exception is raised in a callback on different events ."""
767+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start (tmp_path ):
768+ """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start ."""
769769 class TroublemakerOnTrainBatchStart (Callback ):
770770 def on_train_batch_start (self , trainer , pl_module , batch , batch_idx ):
771771 if batch_idx == 1 :
772772 raise RuntimeError ("Trouble!" )
773773
774+ model = BoringModel ()
775+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
776+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchStart ()], max_epochs = 5 , logger = False )
777+ with pytest .raises (RuntimeError , match = "Trouble!" ):
778+ trainer .fit (model )
779+ assert os .path .isfile (tmp_path / "step=1.ckpt" )
780+
781+
782+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end (tmp_path ):
783+ """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end."""
774784 class TroublemakerOnTrainBatchEnd (Callback ):
775785 def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
776786 if batch_idx == 1 :
777787 raise RuntimeError ("Trouble!" )
788+
789+ model = BoringModel ()
790+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
791+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()], max_epochs = 5 , logger = False )
792+ with pytest .raises (RuntimeError , match = "Trouble!" ):
793+ trainer .fit (model )
778794
795+ assert os .path .isfile (tmp_path / "step=2.ckpt" )
796+
797+
798+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start (tmp_path ):
799+ """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start."""
779800 class TroublemakerOnTrainEpochStart (Callback ):
780801 def on_train_epoch_start (self , trainer , pl_module ):
781802 if trainer .current_epoch == 1 :
782803 raise RuntimeError ("Trouble!" )
804+
805+ model = BoringModel ()
806+ epoch_length = 64
807+ checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
808+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()], max_epochs = 5 , logger = False )
809+ with pytest .raises (RuntimeError , match = "Trouble!" ):
810+ trainer .fit (model )
811+ assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
812+
783813
814+ def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end (tmp_path ):
815+ """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end."""
784816 class TroublemakerOnTrainEpochEnd (Callback ):
785817 def on_train_epoch_end (self , trainer , pl_module ):
786818 if trainer .current_epoch == 1 :
787819 raise RuntimeError ("Trouble!" )
788820
789-
790- epoch_length = 64
791821 model = BoringModel ()
792- # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
822+ epoch_length = 64
793823 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
824+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()], max_epochs = 5 , logger = False )
825+ with pytest .raises (RuntimeError , match = "Trouble!" ):
826+ trainer .fit (model )
827+ assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
828+
829+
830+ # def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
831+ # """Test that the checkpoint is saved when an exception is raised in a callback on different events."""
832+ # class TroublemakerOnTrainBatchStart(Callback):
833+ # def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
834+ # if batch_idx == 1:
835+ # raise RuntimeError("Trouble!")
836+
837+ # class TroublemakerOnTrainBatchEnd(Callback):
838+ # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
839+ # if batch_idx == 1:
840+ # raise RuntimeError("Trouble!")
841+
842+ # class TroublemakerOnTrainEpochStart(Callback):
843+ # def on_train_epoch_start(self, trainer, pl_module):
844+ # if trainer.current_epoch == 1:
845+ # raise RuntimeError("Trouble!")
846+
847+ # class TroublemakerOnTrainEpochEnd(Callback):
848+ # def on_train_epoch_end(self, trainer, pl_module):
849+ # if trainer.current_epoch == 1:
850+ # raise RuntimeError("Trouble!")
851+
852+
853+ # epoch_length = 64
854+ # model = BoringModel()
855+ # # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
856+ # checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
857+
858+ # troublemakers = [
859+ # TroublemakerOnTrainBatchStart(),
860+ # TroublemakerOnTrainBatchEnd(),
861+ # TroublemakerOnTrainEpochStart(),
862+ # TroublemakerOnTrainEpochEnd()
863+ # ]
864+
865+ # expected_ckpts = ["step=1.ckpt",
866+ # 'step=2.ckpt',
867+ # f'step={epoch_length}.ckpt',
868+ # f'step={2*epoch_length}.ckpt',
869+ # ]
870+
871+ # for troublemaker in troublemakers:
872+ # trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
873+
874+ # with pytest.raises(RuntimeError, match="Trouble!"):
875+ # trainer.fit(model)
794876
795- troublemakers = [
796- TroublemakerOnTrainBatchStart (),
797- TroublemakerOnTrainBatchEnd (),
798- TroublemakerOnTrainEpochStart (),
799- TroublemakerOnTrainEpochEnd ()
800- ]
877+ # assert set(os.listdir(tmp_path)) == set(expected_ckpts)
801878
802- expected_ckpts = ["step=1.ckpt" ,
803- 'step=2.ckpt' ,
804- f'step={ epoch_length } .ckpt' ,
805- f'step={ 2 * epoch_length } .ckpt' ,
806- ]
807879
808- for troublemaker in troublemakers :
809- trainer = Trainer ( default_root_dir = tmp_path , callbacks = [ checkpoint_callback , troublemaker ], max_epochs = 5 , logger = False )
880+ # # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
881+ # checkpoint_callback = ModelCheckpoint(dirpath =tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4 )
810882
811- with pytest .raises (RuntimeError , match = "Trouble!" ):
812- trainer .fit (model )
883+ # troublemakers = [
884+ # # TroublemakerOnValidationBatchStart(),
885+ # TroublemakerOnValidationBatchEnd(),
886+ # expected_ckpts = [f"step={2*epoch_length}.ckpt",
813887
814888 assert set (os .listdir (tmp_path )) == set (expected_ckpts )
815889
0 commit comments