|
13 | 13 | # limitations under the License. |
14 | 14 | import pytorch_lightning as pl |
15 | 15 | from pytorch_lightning.trainer.states import TrainerFn |
16 | | -from pytorch_lightning.utilities import rank_zero_warn |
| 16 | +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn |
17 | 17 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
18 | 18 | from pytorch_lightning.utilities.model_helpers import is_overridden |
19 | 19 | from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature |
@@ -75,6 +75,25 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None |
75 | 75 | " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." |
76 | 76 | ) |
77 | 77 |
|
| 78 | + # ---------------------------------------------- |
| 79 | + # verify model does not have |
| 80 | + # - on_train_dataloader |
| 81 | + # - on_val_dataloader |
| 82 | + # ---------------------------------------------- |
| 83 | + has_on_train_dataloader = is_overridden("on_train_dataloader", model) |
| 84 | + if has_on_train_dataloader: |
| 85 | + rank_zero_deprecation( |
| 86 | + "Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." |
| 87 | + " Please use `train_dataloader()` directly." |
| 88 | + ) |
| 89 | + |
| 90 | + has_on_val_dataloader = is_overridden("on_val_dataloader", model) |
| 91 | + if has_on_val_dataloader: |
| 92 | + rank_zero_deprecation( |
| 93 | + "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." |
| 94 | + " Please use `val_dataloader()` directly." |
| 95 | + ) |
| 96 | + |
78 | 97 | trainer = self.trainer |
79 | 98 |
|
80 | 99 | trainer.overriden_optimizer_step = is_overridden("optimizer_step", model) |
@@ -102,10 +121,39 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s |
102 | 121 | if has_step and not has_loader: |
103 | 122 | rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop") |
104 | 123 |
|
| 124 | + # ---------------------------------------------- |
| 125 | + # verify model does not have |
| 126 | + # - on_val_dataloader |
| 127 | + # - on_test_dataloader |
| 128 | + # ---------------------------------------------- |
| 129 | + has_on_val_dataloader = is_overridden("on_val_dataloader", model) |
| 130 | + if has_on_val_dataloader: |
| 131 | + rank_zero_deprecation( |
| 132 | + "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." |
| 133 | + " Please use `val_dataloader()` directly." |
| 134 | + ) |
| 135 | + |
| 136 | + has_on_test_dataloader = is_overridden("on_test_dataloader", model) |
| 137 | + if has_on_test_dataloader: |
| 138 | + rank_zero_deprecation( |
| 139 | + "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." |
| 140 | + " Please use `test_dataloader()` directly." |
| 141 | + ) |
| 142 | + |
105 | 143 | def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None: |
106 | 144 | has_predict_dataloader = is_overridden("predict_dataloader", model) |
107 | 145 | if not has_predict_dataloader: |
108 | 146 | raise MisconfigurationException("Dataloader not found for `Trainer.predict`") |
| 147 | + # ---------------------------------------------- |
| 148 | + # verify model does not have |
| 149 | + # - on_predict_dataloader |
| 150 | + # ---------------------------------------------- |
| 151 | + has_on_predict_dataloader = is_overridden("on_predict_dataloader", model) |
| 152 | + if has_on_predict_dataloader: |
| 153 | + rank_zero_deprecation( |
| 154 | + "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." |
| 155 | + " Please use `predict_dataloader()` directly." |
| 156 | + ) |
109 | 157 |
|
110 | 158 | def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None: |
111 | 159 | """Raise Misconfiguration exception since these hooks are not supported in DP mode""" |
|
0 commit comments