Skip to content

Commit c092385

Browse files
committed
checkpoint on exception put callback tests into a pytest parametrization
1 parent 99af7ed commit c092385

File tree

1 file changed

+93
-211
lines changed

1 file changed

+93
-211
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 93 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -811,241 +811,123 @@ def validation_step(self, batch, batch_idx):
811811
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
812812

813813

814-
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path):
815-
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
814+
CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX = 2
815+
CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH = 21
816+
CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS = 25
817+
CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES = 4
818+
assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX < CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES
819+
assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH < CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS
816820

817-
class TroublemakerOnTrainBatchStart(Callback):
818-
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
819-
if batch_idx == 1:
820-
raise RuntimeError("Trouble!")
821-
822-
model = BoringModel()
823-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
824-
trainer = Trainer(
825-
default_root_dir=tmp_path,
826-
callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()],
827-
max_epochs=5,
828-
logger=False,
829-
enable_progress_bar=False,
830-
)
831-
with pytest.raises(RuntimeError, match="Trouble!"):
832-
trainer.fit(model)
833-
assert os.path.isfile(tmp_path / "step=1.ckpt")
834-
835-
836-
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end(tmp_path):
837-
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end."""
838-
839-
class TroublemakerOnTrainBatchEnd(Callback):
840-
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
841-
if batch_idx == 1:
842-
raise RuntimeError("Trouble!")
843-
844-
model = BoringModel()
845-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
846-
trainer = Trainer(
847-
default_root_dir=tmp_path,
848-
callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()],
849-
max_epochs=5,
850-
logger=False,
851-
enable_progress_bar=False,
852-
)
853-
with pytest.raises(RuntimeError, match="Trouble!"):
854-
trainer.fit(model)
855-
856-
assert os.path.isfile(tmp_path / "step=2.ckpt")
857-
858-
859-
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start(tmp_path):
860-
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start."""
861-
862-
class TroublemakerOnTrainEpochStart(Callback):
863-
def on_train_epoch_start(self, trainer, pl_module):
864-
if trainer.current_epoch == 1:
865-
raise RuntimeError("Trouble!")
866-
867-
model = BoringModel()
868-
epoch_length = 2
869-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
870-
trainer = Trainer(
871-
default_root_dir=tmp_path,
872-
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()],
873-
max_epochs=5,
874-
limit_train_batches=epoch_length,
875-
logger=False,
876-
enable_progress_bar=False,
877-
)
878-
with pytest.raises(RuntimeError, match="Trouble!"):
879-
trainer.fit(model)
880-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
881821

882-
883-
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end(tmp_path):
884-
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end."""
885-
886-
class TroublemakerOnTrainEpochEnd(Callback):
887-
def on_train_epoch_end(self, trainer, pl_module):
888-
if trainer.current_epoch == 1:
889-
raise RuntimeError("Trouble!")
890-
891-
model = BoringModel()
892-
epoch_length = 2
893-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
894-
trainer = Trainer(
895-
default_root_dir=tmp_path,
896-
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()],
897-
max_epochs=5,
898-
limit_train_batches=epoch_length,
899-
logger=False,
900-
enable_progress_bar=False,
901-
)
902-
with pytest.raises(RuntimeError, match="Trouble!"):
903-
trainer.fit(model)
904-
assert os.path.isfile(tmp_path / f"step={2 * epoch_length}.ckpt")
905-
906-
907-
def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path):
908-
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start."""
909-
910-
class TroublemakerOnValidationBatchStart(Callback):
911-
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
912-
if not trainer.sanity_checking and batch_idx == 1:
913-
raise RuntimeError("Trouble!")
914-
915-
model = BoringModel()
916-
epoch_length = 64
917-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
918-
trainer = Trainer(
919-
default_root_dir=tmp_path,
920-
callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()],
921-
max_epochs=5,
922-
logger=False,
923-
enable_progress_bar=False,
924-
)
925-
with pytest.raises(RuntimeError, match="Trouble!"):
926-
trainer.fit(model)
927-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
822+
class TroublemakerOnTrainBatchStart(Callback):
823+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
824+
if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX:
825+
raise RuntimeError("Trouble!")
928826

929827

930-
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path):
931-
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end."""
828+
class TroublemakerOnTrainBatchEnd(Callback):
829+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
830+
if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX:
831+
raise RuntimeError("Trouble!")
932832

933-
class TroublemakerOnValidationBatchEnd(Callback):
934-
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
935-
if not trainer.sanity_checking and batch_idx == 1:
936-
raise RuntimeError("Trouble!")
937833

938-
model = BoringModel()
939-
epoch_length = 64
940-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
941-
trainer = Trainer(
942-
default_root_dir=tmp_path,
943-
callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()],
944-
max_epochs=5,
945-
logger=False,
946-
enable_progress_bar=False,
947-
)
948-
with pytest.raises(RuntimeError, match="Trouble!"):
949-
trainer.fit(model)
950-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
834+
class TroublemakerOnTrainEpochStart(Callback):
835+
def on_train_epoch_start(self, trainer, pl_module):
836+
if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
837+
raise RuntimeError("Trouble!")
951838

952839

953-
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start(tmp_path):
954-
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start."""
840+
class TroublemakerOnTrainEpochEnd(Callback):
841+
def on_train_epoch_end(self, trainer, pl_module):
842+
if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
843+
raise RuntimeError("Trouble!")
955844

956-
class TroublemakerOnValidationEpochStart(Callback):
957-
def on_validation_epoch_start(self, trainer, pl_module):
958-
if not trainer.sanity_checking and trainer.current_epoch == 0:
959-
raise RuntimeError("Trouble!")
960845

961-
model = BoringModel()
962-
epoch_length = 2
963-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
964-
trainer = Trainer(
965-
default_root_dir=tmp_path,
966-
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()],
967-
max_epochs=5,
968-
limit_train_batches=epoch_length,
969-
logger=False,
970-
enable_progress_bar=False,
971-
)
972-
with pytest.raises(RuntimeError, match="Trouble!"):
973-
trainer.fit(model)
974-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
846+
class TroublemakerOnTrainEnd(Callback):
847+
def on_train_end(self, trainer, pl_module):
848+
raise RuntimeError("Trouble!")
975849

976850

977-
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path):
978-
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
851+
class TroublemakerOnValidationBatchStart(Callback):
852+
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
853+
if not trainer.sanity_checking and batch_idx == 1:
854+
raise RuntimeError("Trouble!")
979855

980-
class TroublemakerOnValidationEpochEnd(Callback):
981-
def on_validation_epoch_end(self, trainer, pl_module):
982-
if not trainer.sanity_checking and trainer.current_epoch == 0:
983-
raise RuntimeError("Trouble!")
984856

985-
model = BoringModel()
986-
epoch_length = 2
987-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
988-
trainer = Trainer(
989-
default_root_dir=tmp_path,
990-
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()],
991-
max_epochs=5,
992-
limit_train_batches=epoch_length,
993-
logger=False,
994-
enable_progress_bar=False,
995-
)
996-
with pytest.raises(RuntimeError, match="Trouble!"):
997-
trainer.fit(model)
998-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
857+
class TroublemakerOnValidationBatchEnd(Callback):
858+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
859+
if not trainer.sanity_checking and batch_idx == 1:
860+
raise RuntimeError("Trouble!")
999861

1000862

1001-
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start(tmp_path):
1002-
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_start."""
863+
class TroublemakerOnValidationEpochStart(Callback):
864+
def on_validation_epoch_start(self, trainer, pl_module):
865+
if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
866+
raise RuntimeError("Trouble!")
1003867

1004-
class TroublemakerOnValidationStart(Callback):
1005-
def on_validation_start(self, trainer, pl_module):
1006-
if not trainer.sanity_checking:
1007-
raise RuntimeError("Trouble!")
1008868

1009-
model = BoringModel()
1010-
epoch_length = 2
1011-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
1012-
trainer = Trainer(
1013-
default_root_dir=tmp_path,
1014-
callbacks=[checkpoint_callback, TroublemakerOnValidationStart()],
1015-
max_epochs=5,
1016-
limit_train_batches=epoch_length,
1017-
logger=False,
1018-
enable_progress_bar=False,
1019-
)
1020-
with pytest.raises(RuntimeError, match="Trouble!"):
1021-
trainer.fit(model)
1022-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
869+
class TroublemakerOnValidationEpochEnd(Callback):
870+
def on_validation_epoch_end(self, trainer, pl_module):
871+
if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
872+
raise RuntimeError("Trouble!")
1023873

1024874

1025-
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end(tmp_path):
1026-
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_end."""
875+
class TroublemakerOnValidationStart(Callback):
876+
def on_validation_start(self, trainer, pl_module):
877+
if not trainer.sanity_checking:
878+
raise RuntimeError("Trouble!")
1027879

1028-
class TroublemakerOnValidationEnd(Callback):
1029-
def on_validation_end(self, trainer, pl_module):
1030-
if not trainer.sanity_checking:
1031-
raise RuntimeError("Trouble!")
1032880

1033-
model = BoringModel()
1034-
epoch_length = 2
1035-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
1036-
trainer = Trainer(
1037-
default_root_dir=tmp_path,
1038-
callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()],
1039-
max_epochs=5,
1040-
limit_train_batches=epoch_length,
1041-
logger=False,
1042-
enable_progress_bar=False,
1043-
)
1044-
with pytest.raises(RuntimeError, match="Trouble!"):
1045-
trainer.fit(model)
1046-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
881+
class TroublemakerOnValidationEnd(Callback):
882+
def on_validation_end(self, trainer, pl_module):
883+
if not trainer.sanity_checking:
884+
raise RuntimeError("Trouble!")
1047885

1048886

887+
@pytest.mark.parametrize(
888+
("TroubledCallback", "expected_checkpoint_global_step"),
889+
[
890+
pytest.param(
891+
TroublemakerOnTrainBatchStart, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX, id="on_train_batch_start"
892+
),
893+
pytest.param(
894+
TroublemakerOnTrainBatchEnd, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX + 1, id="on_train_batch_end"
895+
),
896+
pytest.param(
897+
TroublemakerOnTrainEpochStart,
898+
CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
899+
id="on_train_epoch_start",
900+
),
901+
pytest.param(
902+
TroublemakerOnTrainEpochEnd,
903+
(CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
904+
id="on_train_epoch_end",
905+
),
906+
pytest.param(
907+
TroublemakerOnTrainEnd,
908+
CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
909+
id="on_train_end",
910+
),
911+
pytest.param(
912+
TroublemakerOnValidationBatchStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_start"
913+
),
914+
pytest.param(
915+
TroublemakerOnValidationBatchEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_end"
916+
),
917+
pytest.param(
918+
TroublemakerOnValidationEpochStart,
919+
(CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
920+
id="on_validation_epoch_start",
921+
),
922+
pytest.param(
923+
TroublemakerOnValidationEpochEnd,
924+
(CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
925+
id="on_validation_epoch_end",
926+
),
927+
pytest.param(TroublemakerOnValidationStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_start"),
928+
pytest.param(TroublemakerOnValidationEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_end"),
929+
],
930+
)
1049931
@mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
1050932
def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
1051933
"""Tests that the checkpoints are saved at the specified time interval."""

0 commit comments

Comments
 (0)