
Commit d78ea3e

add test for exceptions at different positions in a model

1 parent: 3076ea1

File tree: 1 file changed (+48, −0 lines)

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 48 additions & 0 deletions
@@ -770,6 +770,49 @@ def test_ckpt_every_n_train_steps(tmp_path):
     assert set(os.listdir(tmp_path)) == set(expected)
 
 
+#################################################################################################
+
+
+def test_model_checkpoint_on_exception(tmp_path):
+    """Test that the checkpoint is saved when an exception is raised in a LightningModule."""
+
+    class TroubledModelInTrainingStep(BoringModel):
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 1:
+                raise RuntimeError("Trouble!")
+
+    class TroubledModelInValidationStep(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            if not self.trainer.sanity_checking and batch_idx == 1:
+                raise RuntimeError("Trouble!")
+
+    models = [TroubledModelInTrainingStep(), TroubledModelInValidationStep()]
+
+    for model in models:
+        checkpoint_callback = ModelCheckpoint(
+            dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=4
+        )
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            callbacks=[checkpoint_callback],
+            limit_train_batches=2,
+            max_epochs=5,
+            logger=False,
+            enable_progress_bar=False,
+        )
+
+        with pytest.raises(RuntimeError, match="Trouble!"):
+            trainer.fit(model)
+
+        checkpoint_path = tmp_path / f"exception-{model.__class__.__name__}.ckpt"
+
+        assert os.path.isfile(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+        assert checkpoint["state_dict"] is not None
+        assert checkpoint["state_dict"] != {}
+
+
+#################################################################################################
 def test_model_checkpoint_save_on_exception_in_training_step(tmp_path):
     """Test that the checkpoint is saved when an exception is raised in training_step."""
 
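For readers trying the new flag outside the test suite, here is a minimal sketch of the behavior this test pins down. It assumes, based only on the assertions above, that ModelCheckpoint(save_on_exception=True) writes a checkpoint named exception-<filename>.ckpt into dirpath when Trainer.fit is interrupted by an uncaught exception; FailingModel and run_demo are hypothetical names introduced for this illustration.

    import os

    import pytest
    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.demos.boring_classes import BoringModel


    class FailingModel(BoringModel):
        # Hypothetical model that mirrors TroubledModelInTrainingStep above.
        def training_step(self, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")
            return super().training_step(batch, batch_idx)


    def run_demo(tmp_path):
        # save_on_exception=True is the flag exercised by this commit's test.
        callback = ModelCheckpoint(dirpath=tmp_path, filename="demo", save_on_exception=True)
        trainer = Trainer(
            default_root_dir=tmp_path,
            callbacks=[callback],
            limit_train_batches=2,
            max_epochs=1,
            logger=False,
            enable_progress_bar=False,
        )
        with pytest.raises(RuntimeError, match="Trouble!"):
            trainer.fit(FailingModel())
        # Naming convention taken from the assertions in the test above.
        checkpoint = torch.load(tmp_path / "exception-demo.ckpt", map_location="cpu")
        assert checkpoint["state_dict"]

One design note on the test itself: every_n_epochs=4 appears chosen so the periodic save never fires before the exception at batch_idx == 1 in the first epoch, which makes any exception-*.ckpt file unambiguously the product of the exception path rather than a regular checkpoint.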

@@ -817,6 +860,8 @@ def validation_step(self, batch, batch_idx):
     assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
 
 
+#################################################################################################
+
 CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX = 2
 CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH = 21
 CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS = 25
@@ -957,6 +1002,9 @@ def test_model_checkpoint_save_on_exception_in_other_callbacks(
     assert checkpoint["global_step"] == expected_checkpoint_global_step
 
 
+#################################################################################################
+
+
 @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
 def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
     """Tests that the checkpoints are saved at the specified time interval."""
