Skip to content

Commit d4d933b

Browse files
committed
Add tests for saving the model checkpoint on exception when the exception is raised in the training and validation steps
1 parent 34e598a commit d4d933b

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,37 @@ def test_ckpt_every_n_train_steps(tmp_path):
764764
assert set(os.listdir(tmp_path)) == set(expected)
765765

766766

767+
def test_model_checkpoint_save_on_exception_in_training_step(tmp_path):
    """Test that the checkpoint is saved when an exception is raised in training_step."""

    class TroubledModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # Fail on the second batch so that exactly one batch has been processed beforehand.
            if batch_idx == 1:
                raise RuntimeError("Trouble!")
            # Train normally on every other batch instead of implicitly returning ``None``,
            # so the step counter used in the checkpoint filename reflects real training progress.
            return super().training_step(batch, batch_idx)

    model = TroubledModel()
    # NOTE(review): with ``every_n_epochs=4`` no periodic checkpoint should be due before the failure in the
    # first epoch, so the file asserted below is presumably written by ``save_on_exception`` — confirm.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4
    )
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False)
    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)
    assert os.path.isfile(tmp_path / "step=1.ckpt")
781+
782+
def test_model_checkpoint_save_on_exception_in_validation_step(tmp_path):
    """Test that the checkpoint is saved when an exception is raised in validation_step."""

    class TroubledModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            # Use the trainer attached to the module rather than closing over a local variable of the
            # enclosing test that is only assigned after this class is defined.
            # Skip the sanity-check pass so that the failure happens in the real validation run.
            if not self.trainer.sanity_checking and batch_idx == 0:
                raise RuntimeError("Trouble!")
            return super().validation_step(batch, batch_idx)

    model = TroubledModel()
    # NOTE(review): presumably the number of training batches ``BoringModel`` yields per epoch, i.e. the
    # global step reached when the first real validation run starts — confirm against ``BoringModel``.
    epoch_length = 64
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4
    )
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False)
    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)
    assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
796+
797+
767798
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path):
768799
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
769800
class TroublemakerOnTrainBatchStart(Callback):

0 commit comments

Comments
 (0)