Skip to content

Commit fd3de65

Browse files
committed
add missing test hooks of the Lightning module to test saving a checkpoint on exception
1 parent c921875 commit fd3de65

File tree

1 file changed

+92
-49
lines changed

1 file changed

+92
-49
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 92 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -837,26 +837,20 @@ def on_train_batch_start(self, batch, batch_idx) -> None:
837837
assert os.path.isfile(tmp_path / "already_saved.ckpt")
838838

839839

840-
class TroubledModelOnTrainEpochStart(BoringModel):
841-
def on_train_epoch_start(self):
842-
if self.current_epoch == 1:
843-
raise RuntimeError("Trouble!")
844-
845-
846-
class TroubledModelOnTrainBatchStart(BoringModel):
847-
def on_train_batch_start(self, batch, batch_idx):
840+
class TroubledModelInTrainingStep(BoringModel):
841+
def training_step(self, batch, batch_idx):
848842
if batch_idx == 1:
849843
raise RuntimeError("Trouble!")
850844

851845

852-
class TroubledModelInTrainingStep(BoringModel):
853-
def training_step(self, batch, batch_idx):
854-
if batch_idx == 1:
846+
class TroubledModelInValidationStep(BoringModel):
847+
def validation_step(self, batch, batch_idx):
848+
if not self.trainer.sanity_checking and batch_idx == 1:
855849
raise RuntimeError("Trouble!")
856850

857851

858-
class TroubledModelOnBeforeZeroGrad(BoringModel):
859-
def on_before_zero_grad(self, optimizer):
852+
class TroubledModelBackward(BoringModel):
853+
def backward(self, loss):
860854
if self.current_epoch == 1:
861855
raise RuntimeError("Trouble!")
862856

@@ -873,22 +867,15 @@ def on_after_backward(self):
873867
raise RuntimeError("Trouble!")
874868

875869

876-
class TroubledModelOnBeforeOptimizerStep(BoringModel):
877-
def on_before_optimizer_step(self, optimizer):
870+
class TroubledModelOnBeforeZeroGrad(BoringModel):
871+
def on_before_zero_grad(self, optimizer):
878872
if self.current_epoch == 1:
879873
raise RuntimeError("Trouble!")
880874

881875

882-
class TroubledModelOnTrainBatchEnd(BoringModel):
883-
def on_train_batch_end(self, outputs, batch, batch_idx):
884-
if batch_idx == 1:
885-
raise RuntimeError("Trouble!")
886-
887-
888-
class TroubledModelOnTrainEpochEnd(BoringModel):
889-
def on_train_epoch_end(self):
890-
if self.current_epoch == 1:
891-
raise RuntimeError("Trouble!")
876+
class TroubledModelOnFitEnd(BoringModel):
877+
def on_fit_end(self):
878+
raise RuntimeError("Trouble!")
892879

893880

894881
class TroubledModelOnTrainEnd(BoringModel):
@@ -902,20 +889,38 @@ def on_validation_start(self):
902889
raise RuntimeError("Trouble!")
903890

904891

905-
class TroubledModelOnValidationEpochStart(BoringModel):
906-
def on_validation_epoch_start(self):
907-
if not self.trainer.sanity_checking and self.current_epoch == 1:
892+
class TroubledModelOnValidationEnd(BoringModel):
893+
def on_validation_end(self):
894+
if not self.trainer.sanity_checking:
908895
raise RuntimeError("Trouble!")
909896

910897

911-
class TroubledModelOnValidationBatchStart(BoringModel):
912-
def on_validation_batch_start(self, batch, batch_idx):
913-
if not self.trainer.sanity_checking and batch_idx == 1:
898+
class TroubledModelOnTrainBatchStart(BoringModel):
899+
def on_train_batch_start(self, batch, batch_idx):
900+
if batch_idx == 1:
914901
raise RuntimeError("Trouble!")
915902

916903

917-
class TroubledModelInValidationStep(BoringModel):
918-
def validation_step(self, batch, batch_idx):
904+
class TroubledModelOnTrainBatchEnd(BoringModel):
905+
def on_train_batch_end(self, outputs, batch, batch_idx):
906+
if batch_idx == 1:
907+
raise RuntimeError("Trouble!")
908+
909+
910+
class TroubledModelOnTrainEpochStart(BoringModel):
911+
def on_train_epoch_start(self):
912+
if self.current_epoch == 1:
913+
raise RuntimeError("Trouble!")
914+
915+
916+
class TroubledModelOnTrainEpochEnd(BoringModel):
917+
def on_train_epoch_end(self):
918+
if self.current_epoch == 1:
919+
raise RuntimeError("Trouble!")
920+
921+
922+
class TroubledModelOnValidationBatchStart(BoringModel):
923+
def on_validation_batch_start(self, batch, batch_idx):
919924
if not self.trainer.sanity_checking and batch_idx == 1:
920925
raise RuntimeError("Trouble!")
921926

@@ -926,44 +931,82 @@ def on_validation_batch_end(self, outputs, batch, batch_idx):
926931
raise RuntimeError("Trouble!")
927932

928933

934+
class TroubledModelOnValidationEpochStart(BoringModel):
935+
def on_validation_epoch_start(self):
936+
if not self.trainer.sanity_checking and self.current_epoch == 1:
937+
raise RuntimeError("Trouble!")
938+
939+
929940
class TroubledModelOnValidationEpochEnd(BoringModel):
930941
def on_validation_epoch_end(self):
931942
if not self.trainer.sanity_checking and self.current_epoch == 1:
932943
raise RuntimeError("Trouble!")
933944

934945

935-
class TroubledModelOnValidationEnd(BoringModel):
936-
def on_validation_end(self):
937-
if not self.trainer.sanity_checking:
946+
class TroubledModelOnValidationModelEval(BoringModel):
947+
def on_validation_model_eval(self):
948+
if not self.trainer.sanity_checking and self.current_epoch == 1:
938949
raise RuntimeError("Trouble!")
939950

940951

941-
class TroubledModelOnFitEnd(BoringModel):
942-
def on_fit_end(self):
943-
raise RuntimeError("Trouble!")
952+
class TroubledModelOnValidationModelTrain(BoringModel):
953+
def on_validation_model_train(self):
954+
if not self.trainer.sanity_checking and self.current_epoch == 1:
955+
raise RuntimeError("Trouble!")
956+
957+
958+
class TroubledModelOnBeforeOptimizerStep(BoringModel):
959+
def on_before_optimizer_step(self, optimizer):
960+
if self.current_epoch == 1:
961+
raise RuntimeError("Trouble!")
962+
963+
964+
class TroubledModelConfigureGradienClipping(BoringModel):
965+
def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
966+
if self.current_epoch == 1:
967+
raise RuntimeError("Trouble!")
968+
969+
970+
class TroubledModelOptimizerStep(BoringModel):
971+
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None):
972+
optimizer.step(closure=optimizer_closure)
973+
if self.current_epoch == 1:
974+
raise RuntimeError("Trouble!")
975+
976+
977+
class TroubledModelOptimizerZeroGrad(BoringModel):
978+
def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
979+
if self.current_epoch == 1:
980+
raise RuntimeError("Trouble!")
944981

945982

946983
@pytest.mark.parametrize(
947984
"TroubledModel",
948985
[
949-
TroubledModelOnTrainEpochStart,
950-
TroubledModelOnTrainBatchStart,
951986
TroubledModelInTrainingStep,
952-
TroubledModelOnBeforeZeroGrad,
987+
TroubledModelInValidationStep,
988+
TroubledModelBackward,
953989
TroubledModelOnBeforeBackward,
954990
TroubledModelOnAfterBackward,
955-
TroubledModelOnBeforeOptimizerStep,
956-
TroubledModelOnTrainBatchEnd,
957-
TroubledModelOnTrainEpochEnd,
991+
TroubledModelOnBeforeZeroGrad,
992+
TroubledModelOnFitEnd,
958993
TroubledModelOnTrainEnd,
959994
TroubledModelOnValidationStart,
960-
TroubledModelOnValidationEpochStart,
995+
TroubledModelOnValidationEnd,
996+
TroubledModelOnTrainBatchStart,
997+
TroubledModelOnTrainBatchEnd,
998+
TroubledModelOnTrainEpochStart,
999+
TroubledModelOnTrainEpochEnd,
9611000
TroubledModelOnValidationBatchStart,
962-
TroubledModelInValidationStep,
9631001
TroubledModelOnValidationBatchEnd,
1002+
TroubledModelOnValidationEpochStart,
9641003
TroubledModelOnValidationEpochEnd,
965-
TroubledModelOnValidationEnd,
966-
TroubledModelOnFitEnd,
1004+
TroubledModelOnValidationModelEval,
1005+
TroubledModelOnValidationModelTrain,
1006+
TroubledModelOnBeforeOptimizerStep,
1007+
TroubledModelConfigureGradienClipping,
1008+
TroubledModelOptimizerStep,
1009+
TroubledModelOptimizerZeroGrad,
9671010
],
9681011
)
9691012
def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel):

0 commit comments

Comments
 (0)