Skip to content

Commit c272633

Browse files
committed
split
1 parent f4428e5 commit c272633

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ def test_ckpt_every_n_train_steps(tmp_path):
770770
assert set(os.listdir(tmp_path)) == set(expected)
771771

772772

773-
def test_model_checkpoint_on_exception_run_condition(tmp_path):
773+
def test_model_checkpoint_on_exception_run_condition_on_validation_start(tmp_path):
774774
"""Test that no checkpoint is saved when an exception is raised during a sanity check or a fast dev run, or when a
775775
checkpoint has already been saved at the current training step."""
776776

@@ -796,6 +796,10 @@ def on_validation_start(self) -> None:
796796
trainer.fit(model)
797797
assert not os.path.isfile(tmp_path / "exception-sanity_check.ckpt")
798798

799+
800+
def test_model_checkpoint_on_exception_fast_dev_run_on_train_batch_start(tmp_path):
801+
"""Test that no checkpoint is saved when an exception is raised during a sanity check or a fast dev run, or when a
802+
checkpoint has already been saved at the current training step."""
799803
# Don't save checkpoint if fast dev run fails
800804
class TroubledModelFastDevRun(BoringModel):
801805
def on_train_batch_start(self, batch, batch_idx) -> None:
@@ -817,6 +821,9 @@ def on_train_batch_start(self, batch, batch_idx) -> None:
817821
trainer.fit(model)
818822
assert not os.path.isfile(tmp_path / "exception-fast_dev_run.ckpt")
819823

824+
def test_model_checkpoint_on_exception_run_condition_on_train_batch_start(tmp_path):
825+
"""Test that no checkpoint is saved when an exception is raised during a sanity check or a fast dev run, or when a
826+
checkpoint has already been saved at the current training step."""
820827
# Don't save checkpoint if already saved a checkpoint
821828
class TroubledModelAlreadySavedCheckpoint(BoringModel):
822829
def on_train_batch_start(self, batch, batch_idx) -> None:

0 commit comments

Comments
 (0)