Skip to content

Commit 3a3204e

Browse files
committed
checkpoint on exception add test function for exception in callback
1 parent 904bd74 commit 3a3204e

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,29 @@ def on_validation_end(self, trainer, pl_module):
928928
pytest.param(TroublemakerOnValidationEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_end"),
929929
],
930930
)
931+
def test_model_checkpoint_save_on_exception_in_other_callbacks(
932+
tmp_path, TroubledCallback, expected_checkpoint_global_step
933+
):
934+
"""Test that an checkpoint is saved when an exception is raised in an other callback."""
935+
936+
model = BoringModel()
937+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
938+
trainer = Trainer(
939+
default_root_dir=tmp_path,
940+
callbacks=[checkpoint_callback, TroubledCallback()],
941+
max_epochs=CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS,
942+
limit_train_batches=CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
943+
logger=False,
944+
enable_progress_bar=False,
945+
)
946+
with pytest.raises(RuntimeError, match="Trouble!"):
947+
trainer.fit(model)
948+
949+
assert os.path.isfile(tmp_path / f"step={expected_checkpoint_global_step}.ckpt")
950+
checkpoint = torch.load(tmp_path / f"step={expected_checkpoint_global_step}.ckpt", weights_only=True)
951+
assert checkpoint["global_step"] == expected_checkpoint_global_step
952+
953+
931954
@mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
932955
def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
933956
"""Tests that the checkpoints are saved at the specified time interval."""

0 commit comments

Comments
 (0)