Skip to content

Commit 34e598a

Browse files
committed
add for saving checkpoint on exeption if the exception occurs in a validation callback
1 parent 7d750e6 commit 34e598a

File tree

1 file changed

+84
-50
lines changed

1 file changed

+84
-50
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 84 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -827,66 +827,100 @@ def on_train_epoch_end(self, trainer, pl_module):
827827
assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt")
828828

829829

830-
# def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
831-
# """Test that the checkpoint is saved when an exception is raised in a callback on different events."""
832-
# class TroublemakerOnTrainBatchStart(Callback):
833-
# def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
834-
# if batch_idx == 1:
835-
# raise RuntimeError("Trouble!")
830+
def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path):
831+
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start."""
832+
class TroublemakerOnValidationBatchStart(Callback):
833+
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
834+
if not trainer.sanity_checking and batch_idx == 1:
835+
raise RuntimeError("Trouble!")
836836

837-
# class TroublemakerOnTrainBatchEnd(Callback):
838-
# def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
839-
# if batch_idx == 1:
840-
# raise RuntimeError("Trouble!")
841-
842-
# class TroublemakerOnTrainEpochStart(Callback):
843-
# def on_train_epoch_start(self, trainer, pl_module):
844-
# if trainer.current_epoch == 1:
845-
# raise RuntimeError("Trouble!")
846-
847-
# class TroublemakerOnTrainEpochEnd(Callback):
848-
# def on_train_epoch_end(self, trainer, pl_module):
849-
# if trainer.current_epoch == 1:
850-
# raise RuntimeError("Trouble!")
851-
852-
853-
# epoch_length = 64
854-
# model = BoringModel()
855-
# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
856-
# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
857-
858-
# troublemakers = [
859-
# TroublemakerOnTrainBatchStart(),
860-
# TroublemakerOnTrainBatchEnd(),
861-
# TroublemakerOnTrainEpochStart(),
862-
# TroublemakerOnTrainEpochEnd()
863-
# ]
837+
model = BoringModel()
838+
epoch_length = 64
839+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
840+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], max_epochs=5, logger=False)
841+
with pytest.raises(RuntimeError, match="Trouble!"):
842+
trainer.fit(model)
843+
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
864844

865-
# expected_ckpts = ["step=1.ckpt",
866-
# 'step=2.ckpt',
867-
# f'step={epoch_length}.ckpt',
868-
# f'step={2*epoch_length}.ckpt',
869-
# ]
870845

871-
# for troublemaker in troublemakers:
872-
# trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
846+
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path):
847+
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end."""
848+
class TroublemakerOnValidationBatchEnd(Callback):
849+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
850+
if not trainer.sanity_checking and batch_idx == 1:
851+
raise RuntimeError("Trouble!")
852+
853+
model = BoringModel()
854+
epoch_length = 64
855+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
856+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], max_epochs=5, logger=False)
857+
with pytest.raises(RuntimeError, match="Trouble!"):
858+
trainer.fit(model)
859+
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
873860

874-
# with pytest.raises(RuntimeError, match="Trouble!"):
875-
# trainer.fit(model)
876861

877-
# assert set(os.listdir(tmp_path)) == set(expected_ckpts)
862+
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start(tmp_path):
863+
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start."""
864+
class TroublemakerOnValidationEpochStart(Callback):
865+
def on_validation_epoch_start(self, trainer, pl_module):
866+
if not trainer.sanity_checking and trainer.current_epoch == 0:
867+
raise RuntimeError("Trouble!")
878868

869+
model = BoringModel()
870+
epoch_length = 64
871+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
872+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], max_epochs=5, logger=False)
873+
with pytest.raises(RuntimeError, match="Trouble!"):
874+
trainer.fit(model)
875+
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
876+
879877

880-
# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
881-
# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
878+
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path):
879+
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
880+
class TroublemakerOnValidationEpochEnd(Callback):
881+
def on_validation_epoch_end(self, trainer, pl_module):
882+
if not trainer.sanity_checking and trainer.current_epoch == 0:
883+
raise RuntimeError("Trouble!")
884+
885+
model = BoringModel()
886+
epoch_length = 64
887+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
888+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], max_epochs=5, logger=False)
889+
with pytest.raises(RuntimeError, match="Trouble!"):
890+
trainer.fit(model)
891+
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
882892

883-
# troublemakers = [
884-
# # TroublemakerOnValidationBatchStart(),
885-
# TroublemakerOnValidationBatchEnd(),
886-
# expected_ckpts = [f"step={2*epoch_length}.ckpt",
893+
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start(tmp_path):
894+
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_start."""
895+
class TroublemakerOnValidationStart(Callback):
896+
def on_validation_start(self, trainer, pl_module):
897+
if not trainer.sanity_checking:
898+
raise RuntimeError("Trouble!")
899+
900+
model = BoringModel()
901+
epoch_length = 64
902+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
903+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], max_epochs=5, logger=False)
904+
with pytest.raises(RuntimeError, match="Trouble!"):
905+
trainer.fit(model)
906+
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
887907

888-
assert set(os.listdir(tmp_path)) == set(expected_ckpts)
908+
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end(tmp_path):
909+
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_end."""
910+
class TroublemakerOnValidationEnd(Callback):
911+
def on_validation_end(self, trainer, pl_module):
912+
if not trainer.sanity_checking:
913+
raise RuntimeError("Trouble!")
914+
915+
model = BoringModel()
916+
epoch_length = 64
917+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
918+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], max_epochs=5, logger=False)
919+
with pytest.raises(RuntimeError, match="Trouble!"):
920+
trainer.fit(model)
921+
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
889922

923+
890924
@mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
891925
def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
892926
"""Tests that the checkpoints are saved at the specified time interval."""

0 commit comments

Comments
 (0)