Skip to content

Commit f4a0069

Browse files
krshrimali, pre-commit-ci[bot], awaelchli
authored
set_epoch for validation and prediction data loaders (#12197)
* set_epoch for prediction and evaluation
* minor fix in the test, warning msg was changed

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent f57f2d5 commit f4a0069

File tree

5 files changed: +53 additions, -3 deletions (lines changed)

5 files changed

+53
-3
lines changed

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,18 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
159159
# indicate the loop has run
160160
self._has_run = True
161161

162+
def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
163+
dataloader = self.current_dataloader
164+
if (
165+
dataloader is not None
166+
and getattr(dataloader, "sampler", None)
167+
and callable(getattr(dataloader.sampler, "set_epoch", None))
168+
):
169+
# set seed for distributed sampler (enables shuffling for each epoch)
170+
dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
171+
172+
super().on_advance_start(*args, **kwargs)
173+
162174
def on_advance_end(self) -> None:
163175
self.trainer._logger_connector.epoch_end_reached()
164176

pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,14 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
8787
"""Predicts one entire dataloader."""
8888
void(*args, **kwargs)
8989
dataloader = self.current_dataloader
90+
if (
91+
dataloader is not None
92+
and getattr(dataloader, "sampler", None)
93+
and callable(getattr(dataloader.sampler, "set_epoch", None))
94+
):
95+
# set seed for distributed sampler (enables shuffling for each epoch)
96+
dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
97+
dataloader = self.trainer.strategy.process_dataloader(dataloader)
9098
dataloader_iter = enumerate(dataloader)
9199
dl_max_batches = self.max_batches[self.current_dataloader_idx]
92100

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -484,8 +484,8 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader:
484484
def _check_eval_shuffling(dataloader, mode):
485485
if _is_dataloader_shuffled(dataloader):
486486
rank_zero_warn(
487-
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
488-
" it is strongly recommended that you turn this off for val/test/predict dataloaders.",
487+
f"Your `{mode.dataloader_prefix}_dataloader`'s sampler has shuffling enabled,"
488+
" it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.",
489489
category=PossibleUserWarning,
490490
)
491491

tests/loops/test_evaluation_loop.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from unittest import mock
15+
from unittest.mock import Mock
1516

1617
import torch
1718
from torch.utils.data.dataloader import DataLoader
19+
from torch.utils.data.sampler import RandomSampler
1820

1921
from pytorch_lightning import Trainer
2022
from pytorch_lightning.loops import EvaluationEpochLoop
@@ -42,6 +44,34 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
4244
assert eval_epoch_end_mock.call_count == 4
4345

4446

47+
def test_set_epoch_called_eval_predict(tmpdir):
48+
"""Tests that set_epoch (if the sampler has one) is called on the DataLoader during evaluation and
49+
prediction."""
50+
51+
def _get_dataloader():
52+
dataset = RandomDataset(32, 64)
53+
sampler = RandomSampler(dataset)
54+
sampler.set_epoch = Mock()
55+
return DataLoader(dataset, sampler=sampler)
56+
57+
model = BoringModel()
58+
trainer = Trainer(
59+
default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, enable_model_summary=False
60+
)
61+
62+
train_dataloader = _get_dataloader()
63+
val_dataloader = _get_dataloader()
64+
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
65+
# One for each epoch
66+
assert train_dataloader.sampler.set_epoch.call_count == 2
67+
# One for each epoch + sanity check
68+
assert val_dataloader.sampler.set_epoch.call_count == 3
69+
70+
val_dataloader = _get_dataloader()
71+
trainer.validate(model, val_dataloader)
72+
assert val_dataloader.sampler.set_epoch.call_count == 1
73+
74+
4575
@mock.patch(
4676
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.log_eval_end_metrics"
4777
)

tests/trainer/test_data_loading.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -378,5 +378,5 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
378378
trainer = Trainer()
379379
model = BoringModel()
380380
trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
381-
with pytest.warns(PossibleUserWarning, match="recommended .* turn this off for val/test/predict"):
381+
with pytest.warns(PossibleUserWarning, match="recommended .* turn shuffling off for val/test/predict"):
382382
trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)

0 commit comments

Comments
 (0)