@@ -1,8 +1,10 @@
+import multiprocessing
 import platform
 from abc import ABC, abstractmethod
+from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
-import multiprocessing
 
+import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -41,19 +43,33 @@
 HOROVOD_AVAILABLE = True
 
 
+def _has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader """
+    it is a finite dataloader or infinite dataloader. """
+
     try:
         # try getting the length
         if len(dataloader) == 0:
             raise ValueError('`Dataloader` returned 0 length.'
                              ' Please make sure that your Dataloader at least returns 1 batch')
-        return True
+        has_len = True
     except TypeError:
-        return False
+        has_len = False
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        return False
+        has_len = False
+
+    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
 
 
 class TrainerDataLoadingMixin(ABC):
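
For context, here is a minimal sketch of how the reworked `_has_len` classifies dataloaders. The `MapStyle`, `Stream`, and `SizedStream` datasets are hypothetical examples (not part of this patch), and the import assumes the module path of the file being patched:

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

# assumed import path for the helper patched above
from pytorch_lightning.trainer.data_loading import _has_len


class MapStyle(Dataset):
    # map-style dataset: __len__ is defined, so len(DataLoader) works
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.tensor(idx)


class Stream(IterableDataset):
    # iterable-style dataset without __len__: len(DataLoader) raises TypeError
    def __iter__(self):
        return iter(torch.arange(4))


class SizedStream(Stream):
    # iterable-style dataset that also defines __len__ (legal, but easy to misuse)
    def __len__(self):
        return 4


print(_has_len(DataLoader(MapStyle())))     # True: finite loader
print(_has_len(DataLoader(Stream())))       # False: the TypeError is caught
print(_has_len(DataLoader(SizedStream())))  # True on torch >= 1.4, and the new
                                            # duplicate-samples warning fires
```
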
@@ -128,12 +144,9 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
-        # don't manipulate iterable datasets
         is_dataloader = isinstance(dataloader, DataLoader)
-
-        is_iterable_ds = False
-        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
-            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
+        # don't manipulate iterable datasets
+        is_iterable_ds = _has_iterable_dataset(dataloader)
 
         if not is_dataloader or is_iterable_ds:
             return dataloader
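
To see why the early return matters: PyTorch itself rejects explicit samplers for iterable-style datasets, so `auto_add_sampler` must hand such loaders back untouched rather than inject a `DistributedSampler`. A quick illustration, reusing the hypothetical `Stream` dataset from the sketch above (`_has_iterable_dataset` is module-private, imported the same way as `_has_len`):

```python
from torch.utils.data import DataLoader, SequentialSampler

loader = DataLoader(Stream())
print(_has_iterable_dataset(loader))  # True -> auto_add_sampler returns it as-is

# attaching any sampler to an iterable-style dataset fails inside DataLoader itself
try:
    DataLoader(Stream(), sampler=SequentialSampler(range(4)))
except ValueError as err:
    print(err)
```
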
@@ -285,11 +298,7 @@ def _reset_eval_dataloader(
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                try:
-                    num_batches = len(dataloader)
-                except (TypeError, NotImplementedError):
-                    num_batches = float('inf')
-
+                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
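
The eval path now leans on the same helper instead of duplicating its own try/except. Behaviorally it amounts to this sketch (again with the hypothetical datasets from above): a finite loader keeps its real batch count, while a loader without a usable `__len__` gets `float('inf')` as a sentinel.

```python
for loader in (DataLoader(MapStyle()), DataLoader(Stream())):
    num_batches = len(loader) if _has_len(loader) else float('inf')
    print(num_batches)  # 4 for the map-style loader, inf for the stream
```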