Skip to content

Commit 9e9e580

Browse files
committed
test checkpointing on exception in varoius model steps
1 parent 2ca6dab commit 9e9e580

File tree

1 file changed

+101
-4
lines changed

1 file changed

+101
-4
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -776,27 +776,124 @@ def test_ckpt_every_n_train_steps(tmp_path):
776776
def test_model_checkpoint_on_exception(tmp_path):
777777
"""Test that the checkpoint is saved when an exception is raised in a lightning module."""
778778

779+
class TroubledModelOnTrainEpochStart(BoringModel):
780+
def on_train_epoch_start(self):
781+
if self.current_epoch == 1:
782+
raise RuntimeError("Trouble!")
783+
784+
class TroubledModelOnTrainBatchStart(BoringModel):
785+
def on_train_batch_start(self, batch, batch_idx):
786+
if batch_idx == 1:
787+
raise RuntimeError("Trouble!")
788+
779789
class TroubledModelInTrainingStep(BoringModel):
780790
def training_step(self, batch, batch_idx):
781791
if batch_idx == 1:
782792
raise RuntimeError("Trouble!")
783793

794+
class TroubledModelOnBeforeZeroGrad(BoringModel):
795+
def on_before_zero_grad(self, optimizer):
796+
if self.current_epoch == 1:
797+
raise RuntimeError("Trouble!")
798+
799+
class TroubledModelOnBeforeBackward(BoringModel):
800+
def on_before_backward(self, loss):
801+
if self.current_epoch == 1:
802+
raise RuntimeError("Trouble!")
803+
804+
class TroubledModelOnAfterBackward(BoringModel):
805+
def on_after_backward(self):
806+
if self.current_epoch == 1:
807+
raise RuntimeError("Trouble!")
808+
809+
class TroubledModelOnBeforeOptimizerStep(BoringModel):
810+
def on_before_optimizer_step(self, optimizer):
811+
if self.current_epoch == 1:
812+
raise RuntimeError("Trouble!")
813+
814+
class TroubledModelOnTrainBatchEnd(BoringModel):
815+
def on_train_batch_end(self, outputs, batch, batch_idx):
816+
if batch_idx == 1:
817+
raise RuntimeError("Trouble!")
818+
819+
class TroubledModelOnTrainEpochEnd(BoringModel):
820+
def on_train_epoch_end(self):
821+
if self.current_epoch == 1:
822+
raise RuntimeError("Trouble!")
823+
824+
class TroubledModelOnTrainEnd(BoringModel):
825+
def on_train_end(self):
826+
raise RuntimeError("Trouble!")
827+
828+
class TroubledModelOnValidationStart(BoringModel):
829+
def on_validation_start(self):
830+
if not self.trainer.sanity_checking and self.current_epoch == 1:
831+
raise RuntimeError("Trouble!")
832+
833+
class TroubledModelOnValidationEpochStart(BoringModel):
834+
def on_validation_epoch_start(self):
835+
if not self.trainer.sanity_checking and self.current_epoch == 1:
836+
raise RuntimeError("Trouble!")
837+
838+
class TroubledModelOnValidationBatchStart(BoringModel):
839+
def on_validation_batch_start(self, batch, batch_idx):
840+
if not self.trainer.sanity_checking and batch_idx == 1:
841+
raise RuntimeError("Trouble!")
842+
784843
class TroubledModelInValidationStep(BoringModel):
785844
def validation_step(self, batch, batch_idx):
786-
if not trainer.sanity_checking and batch_idx == 1:
845+
if not self.trainer.sanity_checking and batch_idx == 1:
846+
raise RuntimeError("Trouble!")
847+
848+
class TroubledModelOnValidationBatchEnd(BoringModel):
849+
def on_validation_batch_end(self, outputs, batch, batch_idx):
850+
if not self.trainer.sanity_checking and batch_idx == 1:
851+
raise RuntimeError("Trouble!")
852+
853+
class TroubledModelOnValidationEpochEnd(BoringModel):
854+
def on_validation_epoch_end(self):
855+
if not self.trainer.sanity_checking and self.current_epoch == 1:
856+
raise RuntimeError("Trouble!")
857+
858+
class TroubledModelOnValidationEnd(BoringModel):
859+
def on_validation_end(self):
860+
if not self.trainer.sanity_checking:
787861
raise RuntimeError("Trouble!")
788862

789-
models = [TroubledModelInTrainingStep(), TroubledModelInValidationStep()]
863+
class TroubledModelOnFitEnd(BoringModel):
864+
def on_fit_end(self):
865+
raise RuntimeError("Trouble!")
866+
867+
models = [
868+
TroubledModelOnTrainEpochStart(),
869+
TroubledModelOnTrainBatchStart(),
870+
TroubledModelInTrainingStep(),
871+
TroubledModelOnBeforeZeroGrad(),
872+
TroubledModelOnBeforeBackward(),
873+
TroubledModelOnAfterBackward(),
874+
TroubledModelOnBeforeOptimizerStep(),
875+
TroubledModelOnTrainBatchEnd(),
876+
TroubledModelOnTrainEpochEnd(),
877+
TroubledModelOnTrainEnd(),
878+
TroubledModelOnValidationStart(),
879+
TroubledModelOnValidationEpochStart(),
880+
TroubledModelOnValidationBatchStart(),
881+
TroubledModelInValidationStep(),
882+
TroubledModelOnValidationBatchEnd(),
883+
TroubledModelOnValidationEpochEnd(),
884+
TroubledModelOnValidationEnd(),
885+
TroubledModelOnFitEnd(),
886+
]
790887

791888
for model in models:
792889
checkpoint_callback = ModelCheckpoint(
793-
dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=4
890+
dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=5
794891
)
795892
trainer = Trainer(
796893
default_root_dir=tmp_path,
797894
callbacks=[checkpoint_callback],
798895
limit_train_batches=2,
799-
max_epochs=5,
896+
max_epochs=4,
800897
logger=False,
801898
enable_progress_bar=False,
802899
)

0 commit comments

Comments
 (0)