Commit 927f305

awaelchli, williamFalcon, and Borda authored
Warn user when IterableDataset has __len__ defined (#2437)
* add warning when checking len
* added test
* changelog
* pep
* do not show warning below 1.4
* try version parse
* comments
* xfail
* Update requirements/base.txt (Co-authored-by: Jirka Borovec <[email protected]>)
* Update pytorch_lightning/trainer/data_loading.py (Co-authored-by: Jirka Borovec <[email protected]>)
* version

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
1 parent 325852c commit 927f305

4 files changed: +60 −17 lines
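For context on the side effect this warning describes: when a `DataLoader` uses multiple worker processes, each worker receives its own copy of an `IterableDataset`, so unless the dataset shards itself per worker, every sample is yielded once per worker while `__len__` still reports the unsharded size. A minimal repro sketch (illustrative only; `RangeDataset` is not part of this commit):

```python
from torch.utils.data import DataLoader, IterableDataset


class RangeDataset(IterableDataset):
    """Iterable dataset that also defines __len__ -- the case the new warning flags."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        # no per-worker sharding: every worker iterates the full range
        return iter(range(self.n))

    def __len__(self):
        return self.n


if __name__ == "__main__":
    loader = DataLoader(RangeDataset(4), num_workers=2, batch_size=1)
    print(len(loader))  # 4, as __len__ promises
    print(sorted(int(x) for batch in loader for x in batch))
    # [0, 0, 1, 1, 2, 2, 3, 3] -- every sample is yielded once per worker
```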

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434))
 
+- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437))
+
 ### Changed
 
 

pytorch_lightning/trainer/data_loading.py

Lines changed: 24 additions & 15 deletions
@@ -1,8 +1,10 @@
+import multiprocessing
 import platform
 from abc import ABC, abstractmethod
+from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
-import multiprocessing
 
+import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -41,19 +43,33 @@
 HOROVOD_AVAILABLE = True
 
 
+def _has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader """
+    it is a finite dataloader or infinite dataloader. """
+
     try:
         # try getting the length
         if len(dataloader) == 0:
             raise ValueError('`Dataloader` returned 0 length.'
                              ' Please make sure that your Dataloader at least returns 1 batch')
-        return True
+        has_len = True
     except TypeError:
-        return False
+        has_len = False
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        return False
+        has_len = False
+
+    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
 
 
 class TrainerDataLoadingMixin(ABC):
@@ -128,12 +144,9 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
-        # don't manipulate iterable datasets
         is_dataloader = isinstance(dataloader, DataLoader)
-
-        is_iterable_ds = False
-        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
-            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
+        # don't manipulate iterable datasets
+        is_iterable_ds = _has_iterable_dataset(dataloader)
 
         if not is_dataloader or is_iterable_ds:
             return dataloader
@@ -285,11 +298,7 @@ def _reset_eval_dataloader(
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                try:
-                    num_batches = len(dataloader)
-                except (TypeError, NotImplementedError):
-                    num_batches = float('inf')
-
+                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
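A quick sketch of how the two helpers classify loaders after this change (assumes a PyTorch version that ships `IterableDataset`, i.e. 1.2+; the `Stream` class is illustrative, not part of the commit):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len


class Stream(IterableDataset):
    """Unsized stream: defines __iter__ but not __len__."""

    def __iter__(self):
        return iter(range(10))


finite = DataLoader(TensorDataset(torch.zeros(10)), batch_size=2)
infinite = DataLoader(Stream(), batch_size=2)

# map-style dataset: sized, not iterable-style
assert _has_len(finite) and not _has_iterable_dataset(finite)
# iterable dataset without __len__: len(dataloader) raises TypeError
assert not _has_len(infinite) and _has_iterable_dataset(infinite)
# an IterableDataset that *does* define __len__ makes _has_len return True
# and, on torch >= 1.4, triggers the rank-zero warning added in this commit
```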

requirements/base.txt

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@ tensorboard>=1.14
 future>=0.17.1 # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1 # OmegaConf requirement
-tqdm>=4.41.0
+tqdm>=4.41.0

tests/trainer/test_dataloaders.py

Lines changed: 33 additions & 1 deletion
@@ -2,11 +2,13 @@
 
 import pytest
 import torch
+from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Subset
+from torch.utils.data.dataset import Subset, IterableDataset
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -487,6 +489,36 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
     trainer.test(**test_options)
 
 
+@pytest.mark.xfail(
+    parse(torch.__version__) < parse("1.4.0"),
+    reason="IterableDataset with __len__ before 1.4 raises",
+)
+def test_warning_with_iterable_dataset_and_len(tmpdir):
+    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
+    model = EvalModelTemplate()
+    original_dataset = model.train_dataloader().dataset
+
+    class IterableWithLen(IterableDataset):
+
+        def __iter__(self):
+            return iter(original_dataset)
+
+        def __len__(self):
+            return len(original_dataset)
+
+    dataloader = DataLoader(IterableWithLen(), batch_size=16)
+    assert _has_len(dataloader)
+    assert _has_iterable_dataset(dataloader)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=3,
+    )
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.test(model, test_dataloaders=[dataloader])
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():
