Skip to content

Commit c4b8063

Browse files
committed
Add test to check saving on exception in all relevant callback positions
1 parent 42bbac1 commit c4b8063

File tree

1 file changed

+76
-171
lines changed

1 file changed

+76
-171
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 76 additions & 171 deletions
Original file line number | Diff line number | Diff line change
@@ -813,196 +813,101 @@ def validation_step(self, batch, batch_idx):
813813

814814

815815
#################################################################################################
816-
def test_model_checkpoint_save_on_exception_in_training_step(tmp_path):
817-
"""Test that the checkpoint is saved when an exception is raised in training_step."""
816+
def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path):
817+
"""Test that an checkpoint is saved when an exception is raised in an other callback."""
818818

819-
class TroubledModel(BoringModel):
820-
def training_step(self, batch, batch_idx):
819+
class TroubleMakerOnTrainBatchStart(Callback):
820+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
821821
if batch_idx == 1:
822822
raise RuntimeError("Trouble!")
823823

824-
model = TroubledModel()
825-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
826-
trainer = Trainer(
827-
default_root_dir=tmp_path,
828-
callbacks=[checkpoint_callback],
829-
max_epochs=5,
830-
logger=False,
831-
enable_progress_bar=False,
832-
)
833-
with pytest.raises(RuntimeError, match="Trouble!"):
834-
trainer.fit(model)
835-
print(os.listdir(tmp_path))
836-
assert os.path.isfile(tmp_path / "step=1.ckpt")
837-
838-
839-
def test_model_checkpoint_save_on_exception_in_validation_step(tmp_path):
840-
"""Test that the checkpoint is saved when an exception is raised in validation_step."""
841-
842-
class TroubledModel(BoringModel):
843-
def validation_step(self, batch, batch_idx):
844-
if not trainer.sanity_checking and batch_idx == 0:
824+
class TroubleMakerOnTrainBatchEnd(Callback):
825+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
826+
if batch_idx == 1:
845827
raise RuntimeError("Trouble!")
846828

847-
model = TroubledModel()
848-
epoch_length = 2
849-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
850-
trainer = Trainer(
851-
default_root_dir=tmp_path,
852-
callbacks=[checkpoint_callback],
853-
max_epochs=5,
854-
limit_train_batches=epoch_length,
855-
logger=False,
856-
enable_progress_bar=False,
857-
)
858-
with pytest.raises(RuntimeError, match="Trouble!"):
859-
trainer.fit(model)
860-
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
861-
862-
863-
#################################################################################################
864-
865-
CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX = 2
866-
CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH = 21
867-
CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS = 25
868-
CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES = 4
869-
assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX < CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES
870-
assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH < CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS
871-
872-
873-
class TroublemakerOnTrainBatchStart(Callback):
874-
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
875-
if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX:
876-
raise RuntimeError("Trouble!")
877-
878-
879-
class TroublemakerOnTrainBatchEnd(Callback):
880-
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
881-
if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX:
882-
raise RuntimeError("Trouble!")
883-
884-
885-
class TroublemakerOnTrainEpochStart(Callback):
886-
def on_train_epoch_start(self, trainer, pl_module):
887-
if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
888-
raise RuntimeError("Trouble!")
889-
890-
891-
class TroublemakerOnTrainEpochEnd(Callback):
892-
def on_train_epoch_end(self, trainer, pl_module):
893-
if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
894-
raise RuntimeError("Trouble!")
895-
896-
897-
class TroublemakerOnTrainEnd(Callback):
898-
def on_train_end(self, trainer, pl_module):
899-
raise RuntimeError("Trouble!")
900-
901-
902-
class TroublemakerOnValidationBatchStart(Callback):
903-
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
904-
if not trainer.sanity_checking and batch_idx == 1:
905-
raise RuntimeError("Trouble!")
829+
class TroubleMakerOnTrainEpochStart(Callback):
830+
def on_train_epoch_start(self, trainer, pl_module):
831+
if trainer.current_epoch == 1:
832+
raise RuntimeError("Trouble!")
906833

834+
class TroubleMakerOnTrainEpochEnd(Callback):
835+
def on_train_epoch_end(self, trainer, pl_module):
836+
if trainer.current_epoch == 1:
837+
raise RuntimeError("Trouble!")
907838

908-
class TroublemakerOnValidationBatchEnd(Callback):
909-
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
910-
if not trainer.sanity_checking and batch_idx == 1:
839+
class TroubleMakerOnTrainEnd(Callback):
840+
def on_train_end(self, trainer, pl_module):
911841
raise RuntimeError("Trouble!")
912842

843+
class TroubleMakerOnValidationBatchStart(Callback):
844+
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
845+
if not trainer.sanity_checking and batch_idx == 1:
846+
raise RuntimeError("Trouble!")
913847

914-
class TroublemakerOnValidationEpochStart(Callback):
915-
def on_validation_epoch_start(self, trainer, pl_module):
916-
if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
917-
raise RuntimeError("Trouble!")
918-
848+
class TroubleMakerOnValidationBatchEnd(Callback):
849+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
850+
if not trainer.sanity_checking and batch_idx == 1:
851+
raise RuntimeError("Trouble!")
919852

920-
class TroublemakerOnValidationEpochEnd(Callback):
921-
def on_validation_epoch_end(self, trainer, pl_module):
922-
if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH:
923-
raise RuntimeError("Trouble!")
853+
class TroubleMakerOnValidationEpochStart(Callback):
854+
def on_validation_epoch_start(self, trainer, pl_module):
855+
if not trainer.sanity_checking and trainer.current_epoch == 1:
856+
raise RuntimeError("Trouble!")
924857

858+
class TroubleMakerOnValidationEpochEnd(Callback):
859+
def on_validation_epoch_end(self, trainer, pl_module):
860+
if not trainer.sanity_checking and trainer.current_epoch == 1:
861+
raise RuntimeError("Trouble!")
925862

926-
class TroublemakerOnValidationStart(Callback):
927-
def on_validation_start(self, trainer, pl_module):
928-
if not trainer.sanity_checking:
929-
raise RuntimeError("Trouble!")
863+
class TroubleMakerOnValidationStart(Callback):
864+
def on_validation_start(self, trainer, pl_module):
865+
if not trainer.sanity_checking:
866+
raise RuntimeError("Trouble!")
930867

868+
class TroubleMakerOnValidationEnd(Callback):
869+
def on_validation_end(self, trainer, pl_module):
870+
if not trainer.sanity_checking:
871+
raise RuntimeError("Trouble!")
931872

932-
class TroublemakerOnValidationEnd(Callback):
933-
def on_validation_end(self, trainer, pl_module):
934-
if not trainer.sanity_checking:
873+
class TroubleMakerOnFitEnd(Callback):
874+
def on_fit_end(self, trainer, pl_module):
935875
raise RuntimeError("Trouble!")
936876

877+
troubled_callbacks = [
878+
TroubleMakerOnTrainBatchStart(),
879+
TroubleMakerOnTrainBatchEnd(),
880+
TroubleMakerOnTrainEpochStart(),
881+
TroubleMakerOnTrainEpochEnd(),
882+
TroubleMakerOnTrainEnd(),
883+
TroubleMakerOnValidationBatchStart(),
884+
TroubleMakerOnValidationBatchEnd(),
885+
TroubleMakerOnValidationEpochStart(),
886+
TroubleMakerOnValidationEpochEnd(),
887+
TroubleMakerOnValidationStart(),
888+
TroubleMakerOnValidationEnd(),
889+
TroubleMakerOnFitEnd(),
890+
]
937891

938-
@pytest.mark.parametrize(
939-
("TroubledCallback", "expected_checkpoint_global_step"),
940-
[
941-
pytest.param(
942-
TroublemakerOnTrainBatchStart, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX, id="on_train_batch_start"
943-
),
944-
pytest.param(
945-
TroublemakerOnTrainBatchEnd, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX + 1, id="on_train_batch_end"
946-
),
947-
pytest.param(
948-
TroublemakerOnTrainEpochStart,
949-
CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
950-
id="on_train_epoch_start",
951-
),
952-
pytest.param(
953-
TroublemakerOnTrainEpochEnd,
954-
(CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
955-
id="on_train_epoch_end",
956-
),
957-
pytest.param(
958-
TroublemakerOnTrainEnd,
959-
CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
960-
id="on_train_end",
961-
),
962-
pytest.param(
963-
TroublemakerOnValidationBatchStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_start"
964-
),
965-
pytest.param(
966-
TroublemakerOnValidationBatchEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_end"
967-
),
968-
pytest.param(
969-
TroublemakerOnValidationEpochStart,
970-
(CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
971-
id="on_validation_epoch_start",
972-
),
973-
pytest.param(
974-
TroublemakerOnValidationEpochEnd,
975-
(CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
976-
id="on_validation_epoch_end",
977-
),
978-
pytest.param(TroublemakerOnValidationStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_start"),
979-
pytest.param(TroublemakerOnValidationEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_end"),
980-
],
981-
)
982-
def test_model_checkpoint_save_on_exception_in_other_callbacks(
983-
tmp_path, TroubledCallback, expected_checkpoint_global_step
984-
):
985-
"""Test that an checkpoint is saved when an exception is raised in an other callback."""
986-
987-
model = BoringModel()
988-
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
989-
trainer = Trainer(
990-
default_root_dir=tmp_path,
991-
callbacks=[checkpoint_callback, TroubledCallback()],
992-
max_epochs=CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS,
993-
limit_train_batches=CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES,
994-
logger=False,
995-
enable_progress_bar=False,
996-
)
997-
with pytest.raises(RuntimeError, match="Trouble!"):
998-
trainer.fit(model)
999-
1000-
assert os.path.isfile(tmp_path / f"step={expected_checkpoint_global_step}.ckpt")
1001-
checkpoint = torch.load(tmp_path / f"step={expected_checkpoint_global_step}.ckpt", weights_only=True)
1002-
assert checkpoint["global_step"] == expected_checkpoint_global_step
1003-
1004-
1005-
#################################################################################################
892+
for troubled_callback in troubled_callbacks:
893+
model = BoringModel()
894+
checkpoint_callback = ModelCheckpoint(
895+
dirpath=tmp_path, filename=troubled_callback.__class__.__name__, save_on_exception=True, every_n_epochs=5
896+
)
897+
trainer = Trainer(
898+
default_root_dir=tmp_path,
899+
callbacks=[checkpoint_callback, troubled_callback],
900+
max_epochs=4,
901+
limit_train_batches=2,
902+
logger=False,
903+
enable_progress_bar=False,
904+
)
905+
with pytest.raises(RuntimeError, match="Trouble!"):
906+
trainer.fit(model)
907+
assert os.path.isfile(tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt")
908+
checkpoint = torch.load(tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt", map_location="cpu")
909+
assert checkpoint["state_dict"] is not None
910+
assert checkpoint["state_dict"] != {}
1006911

1007912

1008913
@mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")

0 commit comments

Comments
 (0)