Skip to content

Commit 1657588

Browse files
ninginthecloudSean Narenananthsubtchaton
authored
deprecate on_{train/val/test/predict}_dataloader() from DataHooks (#9098)
Co-authored-by: Sean Naren <[email protected]> Co-authored-by: ananthsub <[email protected]> Co-authored-by: thomas chaton <[email protected]>
1 parent c993d0c commit 1657588

File tree

5 files changed

+141
-5
lines changed

5 files changed

+141
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
161161

162162
- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))
163163

164+
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
165+
166+
-
164167

165168
- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))
166169

pytorch_lightning/core/hooks.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from pytorch_lightning.utilities import move_data_to_device
2222
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
23+
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
2324

2425

2526
class ModelHooks:
@@ -684,16 +685,52 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
684685
raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer")
685686

686687
def on_train_dataloader(self) -> None:
687-
"""Called before requesting the train dataloader."""
688+
"""Called before requesting the train dataloader.
689+
690+
.. deprecated:: v1.5
691+
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
692+
Please use :meth:`train_dataloader()` directly.
693+
"""
694+
rank_zero_deprecation(
695+
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
696+
" Please use `train_dataloader()` directly."
697+
)
688698

689699
def on_val_dataloader(self) -> None:
690-
"""Called before requesting the val dataloader."""
700+
"""Called before requesting the val dataloader.
701+
702+
.. deprecated:: v1.5
703+
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
704+
Please use :meth:`val_dataloader()` directly.
705+
"""
706+
rank_zero_deprecation(
707+
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
708+
" Please use `val_dataloader()` directly."
709+
)
691710

692711
def on_test_dataloader(self) -> None:
693-
"""Called before requesting the test dataloader."""
712+
"""Called before requesting the test dataloader.
713+
714+
.. deprecated:: v1.5
715+
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
716+
Please use :meth:`test_dataloader()` directly.
717+
"""
718+
rank_zero_deprecation(
719+
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
720+
" Please use `test_dataloader()` directly."
721+
)
694722

695723
def on_predict_dataloader(self) -> None:
696-
"""Called before requesting the predict dataloader."""
724+
"""Called before requesting the predict dataloader.
725+
726+
.. deprecated:: v1.5
727+
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
728+
Please use :meth:`predict_dataloader()` directly.
729+
"""
730+
rank_zero_deprecation(
731+
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
732+
" Please use `predict_dataloader()` directly."
733+
)
697734

698735
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
699736
"""

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import pytorch_lightning as pl
1515
from pytorch_lightning.trainer.states import TrainerFn
16-
from pytorch_lightning.utilities import rank_zero_warn
16+
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
1717
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1818
from pytorch_lightning.utilities.model_helpers import is_overridden
1919
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -75,6 +75,25 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None
7575
" `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
7676
)
7777

78+
# ----------------------------------------------
79+
# verify model does not have
80+
# - on_train_dataloader
81+
# - on_val_dataloader
82+
# ----------------------------------------------
83+
has_on_train_dataloader = is_overridden("on_train_dataloader", model)
84+
if has_on_train_dataloader:
85+
rank_zero_deprecation(
86+
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
87+
" Please use `train_dataloader()` directly."
88+
)
89+
90+
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
91+
if has_on_val_dataloader:
92+
rank_zero_deprecation(
93+
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
94+
" Please use `val_dataloader()` directly."
95+
)
96+
7897
trainer = self.trainer
7998

8099
trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
@@ -102,10 +121,39 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s
102121
if has_step and not has_loader:
103122
rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
104123

124+
# ----------------------------------------------
125+
# verify model does not have
126+
# - on_val_dataloader
127+
# - on_test_dataloader
128+
# ----------------------------------------------
129+
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
130+
if has_on_val_dataloader:
131+
rank_zero_deprecation(
132+
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
133+
" Please use `val_dataloader()` directly."
134+
)
135+
136+
has_on_test_dataloader = is_overridden("on_test_dataloader", model)
137+
if has_on_test_dataloader:
138+
rank_zero_deprecation(
139+
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
140+
" Please use `test_dataloader()` directly."
141+
)
142+
105143
def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None:
106144
has_predict_dataloader = is_overridden("predict_dataloader", model)
107145
if not has_predict_dataloader:
108146
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
147+
# ----------------------------------------------
148+
# verify model does not have
149+
# - on_predict_dataloader
150+
# ----------------------------------------------
151+
has_on_predict_dataloader = is_overridden("on_predict_dataloader", model)
152+
if has_on_predict_dataloader:
153+
rank_zero_deprecation(
154+
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
155+
" Please use `predict_dataloader()` directly."
156+
)
109157

110158
def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None:
111159
"""Raise Misconfiguration exception since these hooks are not supported in DP mode"""

tests/deprecated_api/test_remove_1-7.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,27 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
9191
_ = Trainer(prepare_data_per_node=False)
9292

9393

94+
def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
95+
96+
model = BoringModel()
97+
with pytest.deprecated_call(
98+
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
99+
):
100+
model.on_train_dataloader()
101+
with pytest.deprecated_call(
102+
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
103+
):
104+
model.on_val_dataloader()
105+
with pytest.deprecated_call(
106+
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
107+
):
108+
model.on_test_dataloader()
109+
with pytest.deprecated_call(
110+
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
111+
):
112+
model.on_predict_dataloader()
113+
114+
94115
@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
95116
def test_v1_7_0_test_tube_logger(_, tmpdir):
96117
with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):

tests/trainer/test_trainer.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,3 +1868,30 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
18681868
trainer.test(model)
18691869
with pytest.raises(Exception, match=r"Error during predict"), patch("pytorch_lightning.Trainer._on_exception"):
18701870
trainer.predict(model, model.val_dataloader(), return_predictions=False)
1871+
1872+
1873+
def test_overridden_on_dataloaders(tmpdir):
1874+
model = BoringModel()
1875+
with pytest.deprecated_call(
1876+
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1877+
):
1878+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1879+
trainer.fit(model)
1880+
1881+
with pytest.deprecated_call(
1882+
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1883+
):
1884+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1885+
trainer.validate(model)
1886+
1887+
with pytest.deprecated_call(
1888+
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1889+
):
1890+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1891+
trainer.test(model)
1892+
1893+
with pytest.deprecated_call(
1894+
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1895+
):
1896+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1897+
trainer.predict(model)

0 commit comments

Comments
 (0)