Skip to content

Commit c921875

Browse files
committed
add missing callback hooks for Test Troubled Callback and order them according to documentation
1 parent 09ba24f commit c921875

File tree

1 file changed

+51
-20
lines changed

1 file changed

+51
-20
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,11 @@ def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel):
994994
assert checkpoint["state_dict"] != {}
995995

996996

997+
class TroubledCallbackOnFitEnd(Callback):
998+
def on_fit_end(self, trainer, pl_module):
999+
raise RuntimeError("Trouble!")
1000+
1001+
9971002
class TroubledCallbackOnTrainBatchStart(Callback):
9981003
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
9991004
if batch_idx == 1:
@@ -1018,9 +1023,16 @@ def on_train_epoch_end(self, trainer, pl_module):
10181023
raise RuntimeError("Trouble!")
10191024

10201025

1021-
class TroubledCallbackOnTrainEnd(Callback):
1022-
def on_train_end(self, trainer, pl_module):
1023-
raise RuntimeError("Trouble!")
1026+
class TroubledCallbackOnValidationEpochStart(Callback):
1027+
def on_validation_epoch_start(self, trainer, pl_module):
1028+
if not trainer.sanity_checking and trainer.current_epoch == 1:
1029+
raise RuntimeError("Trouble!")
1030+
1031+
1032+
class TroubledCallbackOnValidationEpochEnd(Callback):
1033+
def on_validation_epoch_end(self, trainer, pl_module):
1034+
if not trainer.sanity_checking and trainer.current_epoch == 1:
1035+
raise RuntimeError("Trouble!")
10241036

10251037

10261038
class TroubledCallbackOnValidationBatchStart(Callback):
@@ -1035,16 +1047,9 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)
10351047
raise RuntimeError("Trouble!")
10361048

10371049

1038-
class TroubledCallbackOnValidationEpochStart(Callback):
1039-
def on_validation_epoch_start(self, trainer, pl_module):
1040-
if not trainer.sanity_checking and trainer.current_epoch == 1:
1041-
raise RuntimeError("Trouble!")
1042-
1043-
1044-
class TroubledCallbackOnValidationEpochEnd(Callback):
1045-
def on_validation_epoch_end(self, trainer, pl_module):
1046-
if not trainer.sanity_checking and trainer.current_epoch == 1:
1047-
raise RuntimeError("Trouble!")
1050+
class TroubledCallbackOnTrainEnd(Callback):
1051+
def on_train_end(self, trainer, pl_module):
1052+
raise RuntimeError("Trouble!")
10481053

10491054

10501055
class TroubledCallbackOnValidationStart(Callback):
@@ -1059,26 +1064,52 @@ def on_validation_end(self, trainer, pl_module):
10591064
raise RuntimeError("Trouble!")
10601065

10611066

1062-
class TroubledCallbackOnFitEnd(Callback):
1063-
def on_fit_end(self, trainer, pl_module):
1064-
raise RuntimeError("Trouble!")
1067+
class TroubleCallbackOnBeforeBackward(Callback):
1068+
def on_before_backward(self, trainer, pl_module, loss):
1069+
if trainer.current_epoch == 1:
1070+
raise RuntimeError("Trouble!")
1071+
1072+
1073+
class TroubleCallbackOnAfterBackward(Callback):
1074+
def on_after_backward(self, trainer, pl_module):
1075+
if trainer.current_epoch == 1:
1076+
raise RuntimeError("Trouble!")
1077+
1078+
1079+
class TroubleCallbackOnBeforeOptimizerStep(Callback):
1080+
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
1081+
if trainer.current_epoch == 1:
1082+
raise RuntimeError("Trouble!")
1083+
1084+
1085+
class TroubleCallbackOnBeforeZeroGrad(Callback):
1086+
def on_before_zero_grad(self, trainer, pl_module, optimizer):
1087+
if trainer.current_epoch == 1:
1088+
raise RuntimeError("Trouble!")
1089+
1090+
1091+
####
10651092

10661093

10671094
@pytest.mark.parametrize(
10681095
"TroubledCallback",
10691096
[
1097+
TroubledCallbackOnFitEnd,
10701098
TroubledCallbackOnTrainBatchStart,
10711099
TroubledCallbackOnTrainBatchEnd,
10721100
TroubledCallbackOnTrainEpochStart,
10731101
TroubledCallbackOnTrainEpochEnd,
1074-
TroubledCallbackOnTrainEnd,
1075-
TroubledCallbackOnValidationBatchStart,
1076-
TroubledCallbackOnValidationBatchEnd,
10771102
TroubledCallbackOnValidationEpochStart,
10781103
TroubledCallbackOnValidationEpochEnd,
1104+
TroubledCallbackOnValidationBatchStart,
1105+
TroubledCallbackOnValidationBatchEnd,
1106+
TroubledCallbackOnTrainEnd,
10791107
TroubledCallbackOnValidationStart,
10801108
TroubledCallbackOnValidationEnd,
1081-
TroubledCallbackOnFitEnd,
1109+
TroubleCallbackOnBeforeBackward,
1110+
TroubleCallbackOnAfterBackward,
1111+
TroubleCallbackOnBeforeOptimizerStep,
1112+
TroubleCallbackOnBeforeZeroGrad,
10821113
],
10831114
)
10841115
def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path, TroubledCallback):

0 commit comments

Comments
 (0)