@@ -770,7 +770,7 @@ def test_ckpt_every_n_train_steps(tmp_path):
770770 assert set (os .listdir (tmp_path )) == set (expected )
771771
772772
773- def test_model_checkpoint_on_exception_run_condition (tmp_path ):
773+ def test_model_checkpoint_on_exception_run_condition_on_validation_start (tmp_path ):
774774 """Test that no checkpoint is saved when an exception is raised during a sanity check or a fast dev run, or when a
775775 checkpoint has already been saved at the current training step."""
776776
@@ -796,6 +796,10 @@ def on_validation_start(self) -> None:
796796 trainer .fit (model )
797797 assert not os .path .isfile (tmp_path / "exception-sanity_check.ckpt" )
798798
799+
800+ def test_model_checkpoint_on_exception_fast_dev_run_on_train_batch_start (tmp_path ):
801+ """Test that no checkpoint is saved when an exception is raised during a sanity check or a fast dev run, or when a
802+ checkpoint has already been saved at the current training step."""
799803 # Don't save checkpoint if fast dev run fails
800804 class TroubledModelFastDevRun (BoringModel ):
801805 def on_train_batch_start (self , batch , batch_idx ) -> None :
@@ -817,6 +821,9 @@ def on_train_batch_start(self, batch, batch_idx) -> None:
817821 trainer .fit (model )
818822 assert not os .path .isfile (tmp_path / "exception-fast_dev_run.ckpt" )
819823
824+ def test_model_checkpoint_on_exception_run_condition_on_train_batch_start (tmp_path ):
825+ """Test that no checkpoint is saved when an exception is raised during a sanity check or a fast dev run, or when a
826+ checkpoint has already been saved at the current training step."""
820827 # Don't save checkpoint if already saved a checkpoint
821828 class TroubledModelAlreadySavedCheckpoint (BoringModel ):
822829 def on_train_batch_start (self , batch , batch_idx ) -> None :
0 commit comments