Skip to content

Commit ac33670

Browse files
committed
add test for run conditions for save checkpoint on exception
1 parent d2f74e9 commit ac33670

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,73 @@ def test_ckpt_every_n_train_steps(tmp_path):
770770
assert set(os.listdir(tmp_path)) == set(expected)
771771

772772

773+
def test_model_checkpoint_on_exception_run_condition(tmp_path):
774+
"""Test that the checkpoint is saved when an exception is raised in a lightning module."""
775+
776+
# Don't save checkpoint if sanity check fails
777+
class TroubledModelSanityCheck(BoringModel):
778+
def on_validation_start(self) -> None:
779+
if self.trainer.sanity_checking:
780+
print("Trouble!")
781+
raise RuntimeError("Trouble!")
782+
783+
model = TroubledModelSanityCheck()
784+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="sanity_check", save_on_exception=True)
785+
trainer = Trainer(
786+
default_root_dir=tmp_path,
787+
num_sanity_val_steps=4,
788+
limit_train_batches=2,
789+
callbacks=[checkpoint_callback],
790+
max_epochs=2,
791+
logger=False,
792+
)
793+
794+
with pytest.raises(RuntimeError, match="Trouble!"):
795+
trainer.fit(model)
796+
assert not os.path.isfile(tmp_path / "exception-sanity_check.ckpt")
797+
798+
# Don't save checkpoint if fast dev run fails
799+
class TroubledModelFastDevRun(BoringModel):
800+
def on_train_batch_start(self, batch, batch_idx) -> None:
801+
if self.trainer.fast_dev_run and batch_idx == 1:
802+
raise RuntimeError("Trouble!")
803+
804+
model = TroubledModelFastDevRun()
805+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="fast_dev_run", save_on_exception=True)
806+
trainer = Trainer(
807+
default_root_dir=tmp_path,
808+
fast_dev_run=2,
809+
limit_train_batches=2,
810+
callbacks=[checkpoint_callback],
811+
max_epochs=2,
812+
logger=False,
813+
)
814+
815+
with pytest.raises(RuntimeError, match="Trouble!"):
816+
trainer.fit(model)
817+
assert not os.path.isfile(tmp_path / "exception-fast_dev_run.ckpt")
818+
819+
# Don't save checkpoint if already saved a checkpoint
820+
class TroubledModelAlreadySavedCheckpoint(BoringModel):
821+
def on_train_batch_start(self, batch, batch_idx) -> None:
822+
if self.trainer.global_step == 1:
823+
raise RuntimeError("Trouble!")
824+
825+
model = TroubledModelAlreadySavedCheckpoint()
826+
checkpoint_callback = ModelCheckpoint(
827+
dirpath=tmp_path, filename="already_saved", save_on_exception=True, every_n_train_steps=1
828+
)
829+
trainer = Trainer(
830+
default_root_dir=tmp_path, limit_train_batches=2, callbacks=[checkpoint_callback], max_epochs=2, logger=False
831+
)
832+
833+
with pytest.raises(RuntimeError, match="Trouble!"):
834+
trainer.fit(model)
835+
836+
assert not os.path.isfile(tmp_path / "exception-already_saved.ckpt")
837+
assert os.path.isfile(tmp_path / "already_saved.ckpt")
838+
839+
773840
def test_model_checkpoint_on_exception(tmp_path):
774841
"""Test that the checkpoint is saved when an exception is raised in a lightning module."""
775842

0 commit comments

Comments
 (0)