Skip to content

Commit e0dae53

Browse files
committed
add test for exception in training callbacks
1 parent 136e59a commit e0dae53

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,55 @@ def test_ckpt_every_n_train_steps(tmp_path):
764764
assert set(os.listdir(tmp_path)) == set(expected)
765765

766766

767+
def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
    """Test that a checkpoint is saved when an exception is raised in a callback during training.

    Covers exceptions raised from ``on_train_batch_start``, ``on_train_batch_end``,
    ``on_train_epoch_start`` and ``on_train_epoch_end``.
    """

    class TroublemakerOnTrainBatchStart(Callback):
        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainBatchEnd(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainEpochStart(Callback):
        def on_train_epoch_start(self, trainer, pl_module):
            if trainer.current_epoch == 1:
                raise RuntimeError("Trouble!")

    class TroublemakerOnTrainEpochEnd(Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            if trainer.current_epoch == 1:
                raise RuntimeError("Trouble!")

    epoch_length = 64
    model = BoringModel()

    troublemakers = [
        TroublemakerOnTrainBatchStart(),
        TroublemakerOnTrainBatchEnd(),
        TroublemakerOnTrainEpochStart(),
        TroublemakerOnTrainEpochEnd(),
    ]

    # One exception checkpoint per troublemaker, in loop order:
    # - batch_start raises before batch 1 is processed -> step=1
    # - batch_end raises after batch 1 is processed -> step=2
    # - epoch_start raises at the start of epoch 1 (one full epoch done) -> step=epoch_length
    # - epoch_end raises at the end of epoch 1 -> step=2*epoch_length
    expected_ckpts = [
        "step=1.ckpt",
        "step=2.ckpt",
        f"step={epoch_length}.ckpt",
        f"step={2 * epoch_length}.ckpt",
    ]

    for troublemaker in troublemakers:
        # Use a fresh ModelCheckpoint per run: callback instances keep internal
        # state (e.g. the last saved step) and should not be shared across
        # independent Trainer/fit runs. every_n_epochs=4 ensures regular
        # checkpoints are distinguishable from the exception-triggered ones
        # (training never reaches epoch 4, so only exception checkpoints appear).
        checkpoint_callback = ModelCheckpoint(
            dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4
        )
        trainer = Trainer(
            default_root_dir=tmp_path,
            callbacks=[checkpoint_callback, troublemaker],
            max_epochs=5,
            logger=False,
        )

        with pytest.raises(RuntimeError, match="Trouble!"):
            trainer.fit(model)

    # All four exception checkpoints (and nothing else) accumulated in tmp_path.
    assert set(os.listdir(tmp_path)) == set(expected_ckpts)
815+
767816
@mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
768817
def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
769818
"""Tests that the checkpoints are saved at the specified time interval."""

0 commit comments

Comments
 (0)