Skip to content

Commit c014dd1

Browse files
rohitgr7 authored and lexierule committed
Fix support for CombinedLoader while checking for warning raised with eval dataloaders (#10994)
1 parent 82d7d50 commit c014dd1

File tree

4 files changed

+50
-7
lines changed

4 files changed

+50
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818
- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))
1919

2020

21+
- Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))
22+
23+
2124
-
2225

2326

pytorch_lightning/trainer/data_loading.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,11 @@ def _reset_eval_dataloader(
455455
loader, SequentialSampler(loader.dataset), mode=mode
456456
)
457457
else:
458-
rank_zero_warn(
459-
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
460-
"it is strongly recommended that you turn this off for val/test/predict dataloaders."
458+
apply_to_collection(
459+
loader.loaders if isinstance(loader, CombinedLoader) else loader,
460+
DataLoader,
461+
self._check_eval_shuffling,
462+
mode=mode,
461463
)
462464

463465
if any(dl is None for dl in dataloaders):
@@ -620,3 +622,16 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader:
620622
dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)
621623

622624
return dataloader
625+
626+
@staticmethod
627+
def _check_eval_shuffling(dataloader, mode):
628+
if (
629+
hasattr(dataloader, "sampler")
630+
and not isinstance(dataloader.sampler, SequentialSampler)
631+
and not isinstance(dataloader.dataset, IterableDataset)
632+
):
633+
rank_zero_warn(
634+
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
635+
" it is strongly recommended that you turn this off for val/test/predict dataloaders.",
636+
category=UserWarning,
637+
)

pytorch_lightning/trainer/supporters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,10 @@ def __len__(self) -> int:
304304

305305

306306
class CombinedLoader:
307-
"""Combines different dataloaders and allows sampling in parallel. Supported modes are 'min_size', which raises
308-
StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
309-
'max_size_cycle` which raises StopIteration after the longest loader (the one with most batches) is done, while
310-
cycling through the shorter loaders.
307+
"""Combines different dataloaders and allows sampling in parallel. Supported modes are ``"min_size"``, which
308+
raises StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
309+
``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is done,
310+
while cycling through the shorter loaders.
311311
312312
Examples:
313313
>>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4),

tests/trainer/test_data_loading.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from pytorch_lightning import Trainer
2323
from pytorch_lightning.trainer.states import RunningStage
24+
from pytorch_lightning.trainer.supporters import CombinedLoader
2425
from pytorch_lightning.utilities.enums import DistributedType
2526
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2627
from tests.helpers import BoringModel, RandomDataset
@@ -364,3 +365,27 @@ def test_error_raised_with_float_limited_eval_batches():
364365
match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
365366
):
366367
trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
368+
369+
370+
@pytest.mark.parametrize(
371+
"val_dl",
372+
[
373+
DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
374+
CombinedLoader(DataLoader(dataset=RandomDataset(32, 64), shuffle=True)),
375+
CombinedLoader(
376+
[DataLoader(dataset=RandomDataset(32, 64)), DataLoader(dataset=RandomDataset(32, 64), shuffle=True)]
377+
),
378+
CombinedLoader(
379+
{
380+
"dl1": DataLoader(dataset=RandomDataset(32, 64)),
381+
"dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
382+
}
383+
),
384+
],
385+
)
386+
def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
387+
trainer = Trainer()
388+
model = BoringModel()
389+
trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
390+
with pytest.warns(UserWarning, match="recommended .* turn this off for val/test/predict"):
391+
trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)

0 commit comments

Comments (0)