
Commit f6ed0bd

Authored by: ninginthecloud, pre-commit-ci[bot], awaelchli, tchaton, rohitgr7
introduce has_len_all_ranks() to check the length of dataloader across ranks (#9827)
* introduce `has_len_all_ranks()`, update tests
* update CHANGELOG.md
* change staticmethod and hook attribute naming
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* fix typo
* remove non-essential comment
* fix merge error and comment format
* try to fix test_tpu.py failure
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* update on comments
* chlog
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* chlog
* update
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* try fix
* Revert back TPUSpawn changes
* Update test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
1 parent 34fcb87 commit f6ed0bd

8 files changed: +104 -23 lines changed


CHANGELOG.md

Lines changed: 1 addition & 6 deletions

@@ -179,12 +179,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
 - LightningModule now raises an error when calling `log(on_step=False, on_epoch=False)` ([#10227](https://github.com/PyTorchLightning/pytorch-lightning/pull/10227))
 - Quantization aware training observers are now disabled by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
+- Raised `MisconfigurationException` when total length of `dataloader` across ranks is zero, and give warning when total length is non-zero, but only local rank length is zero. ([#9827](https://github.com/PyTorchLightning/pytorch-lightning/pull/9827))
 - Changed the model size calculation using `ByteCounter` ([#10123](https://github.com/PyTorchLightning/pytorch-lightning/pull/10123))
-
-
 - Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
-
-
 - Allow separate config files for parameters with class type when LightningCLI is in subclass_mode=False ([#10286](https://github.com/PyTorchLightning/pytorch-lightning/pull/10286))

@@ -221,8 +218,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `ClusterEnvironment.creates_children()` in favor of `ClusterEnvironment.creates_processes_externally` (property) ([#10106](https://github.com/PyTorchLightning/pytorch-lightning/pull/10106))
 - Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105))
 - Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))
-
-
 - Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134))

pytorch_lightning/core/hooks.py

Lines changed: 4 additions & 0 deletions

@@ -314,9 +314,13 @@ def __init__(self) -> None:
         prepare_data_per_node:
             If True, each LOCAL_RANK=0 will call prepare data.
             Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
+        allow_zero_length_dataloader_with_multiple_devices:
+            If True, dataloader with zero length within local rank is allowed.
+            Default value is False.
         """
         super().__init__()
         self.prepare_data_per_node: bool = True
+        self.allow_zero_length_dataloader_with_multiple_devices: bool = False

     def prepare_data(self) -> None:
         """Use this to download and prepare data.

pytorch_lightning/trainer/data_loading.py

Lines changed: 14 additions & 4 deletions

@@ -37,7 +37,7 @@
     CaptureMapDataset,
     FastForwardSampler,
 )
-from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training

@@ -346,7 +346,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
         self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)

-        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")
+        module = model or self.lightning_module or self.datamodule
+        self.num_training_batches = (
+            len(self.train_dataloader)
+            if has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module)
+            else float("inf")
+        )

         if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
             self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))

@@ -371,7 +376,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
                 "If you want to disable validation set `limit_val_batches` to 0.0 instead."
             )
         else:
-            if not has_len(self.train_dataloader):
+            if not has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module):
                 if self.val_check_interval == 1.0:
                     self.val_check_batch = float("inf")
                 else:

@@ -452,9 +457,14 @@ def _reset_eval_dataloader(
         # determine number of batches
         # datasets could be none, 1 or 2+
+        module = model or self.lightning_module or self.datamodule
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = len(dataloader) if has_len(dataloader) else float("inf")
+                num_batches = (
+                    len(dataloader)
+                    if has_len_all_ranks(dataloader, self.training_type_plugin, module)
+                    else float("inf")
+                )
                 self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

                 # percent or num_steps
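Both `reset_train_dataloader` and `_reset_eval_dataloader` now share the same pattern: use the real length only when it is usable on every rank, otherwise fall back to an unbounded batch count. A standalone sketch of that pattern, wrapped in a hypothetical helper (the helper name is not part of this commit):

from typing import Union

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.utilities.data import has_len_all_ranks


def resolve_num_batches(dataloader: DataLoader, trainer: "pl.Trainer", module) -> Union[int, float]:
    # Hypothetical helper restating the trainer logic above: take len() when the loader
    # reports a usable length on all ranks, otherwise treat it as unbounded so that
    # limits such as max_steps or limit_*_batches govern the loop instead.
    return (
        len(dataloader)
        if has_len_all_ranks(dataloader, trainer.training_type_plugin, module)
        else float("inf")
    )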

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 7 additions & 4 deletions

@@ -16,11 +16,13 @@
 import uuid
 from typing import Optional, Tuple

+from torch.utils.data import DataLoader
+
 import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.data import has_len
+from pytorch_lightning.utilities.data import has_len_all_ranks
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
 from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr

@@ -257,13 +259,14 @@ def _adjust_batch_size(
     if desc:
         log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")

-    if not _is_valid_batch_size(new_size, trainer.train_dataloader):
+    if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
         new_size = min(new_size, len(trainer.train_dataloader.dataset))

     changed = new_size != batch_size
     lightning_setattr(model, batch_arg_name, new_size)
     return new_size, changed


-def _is_valid_batch_size(current_size, dataloader):
-    return not has_len(dataloader) or current_size <= len(dataloader)
+def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
+    module = trainer.lightning_module or trainer.datamodule
+    return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or batch_size <= len(dataloader)
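To make the accept/reject condition explicit, here is a rough, self-contained restatement of the rule `_is_valid_batch_size` implements after this change; all names below are made up for illustration:

def batch_size_fits(candidate: int, dataloader_len: int, has_usable_len: bool) -> bool:
    # `has_usable_len` stands in for the result of has_len_all_ranks(...): a loader without
    # a usable length accepts any candidate, otherwise the candidate must not exceed
    # len(dataloader), exactly as in the return statement above.
    return not has_usable_len or candidate <= dataloader_len


assert batch_size_fits(64, dataloader_len=128, has_usable_len=True)
assert not batch_size_fits(256, dataloader_len=128, has_usable_len=True)
assert batch_size_fits(4096, dataloader_len=0, has_usable_len=False)  # "infinite" loader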

pytorch_lightning/utilities/data.py

Lines changed: 51 additions & 0 deletions

@@ -17,7 +17,9 @@
 import torch
 from torch.utils.data import DataLoader, IterableDataset

+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache

 BType = Union[torch.Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]

@@ -93,6 +95,55 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
     return has_len


+def has_len_all_ranks(
+    dataloader: DataLoader,
+    training_type: "pl.TrainingTypePlugin",
+    model: Union["pl.LightningModule", "pl.LightningDataModule"],
+) -> bool:
+    """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
+    infinite dataloader.
+
+    Raises:
+        ValueError:
+            If the length of Dataloader is 0, as it requires at least one batch
+    """
+    try:
+        total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
+        local_length = len(dataloader)
+
+        if total_length == 0:
+            raise MisconfigurationException(
+                "Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch."
+            )
+        if total_length > 0 and local_length == 0:
+            if model.allow_zero_length_dataloader_with_multiple_devices:
+                rank_zero_warn(
+                    "Total length of `Dataloader` across ranks is zero, but local rank has zero length."
+                    " Please be cautious of uneven batch length."
+                )
+                has_len = False
+            else:
+                raise MisconfigurationException(
+                    "`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch."
+                )
+        else:
+            has_len = True
+
+    except TypeError:
+        has_len = False
+    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+        has_len = False
+
+    if has_len and has_iterable_dataset(dataloader):
+        rank_zero_warn(
+            "Your `IterableDataset` has `__len__` defined."
+            " In combination with multi-process data loading (when num_workers > 1),"
+            " `__len__` could be inaccurate if each worker is not configured independently"
+            " to avoid having duplicate data."
+        )
+    return has_len
+
+
 def get_len(dataloader: DataLoader) -> Union[int, float]:
     """Return the length of the given DataLoader.
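For a quick sense of the new helper on a single device, a usage sketch mirroring the test added at the bottom of this commit (it reuses the repository's test helpers, so it only runs from a source checkout):

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import has_len_all_ranks
from tests.helpers.boring_model import BoringModel, RandomDataset

trainer = Trainer(fast_dev_run=True)
model = BoringModel()

# A loader with at least one batch on the (only) rank: the helper returns True.
assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)

# An empty loader would instead raise MisconfigurationException, as exercised in
# tests/utilities/test_data.py below.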

tests/accelerators/test_tpu.py

Lines changed: 0 additions & 1 deletion

@@ -308,7 +308,6 @@ def test_xla_checkpoint_plugin_being_default():
 def test_mp_device_dataloader_attribute(_):
     dataset = RandomDataset(32, 64)
     dataloader = TPUSpawnPlugin().process_dataloader(DataLoader(dataset))
-
     assert dataloader.dataset == dataset

tests/trainer/test_dataloaders.py

Lines changed: 6 additions & 6 deletions

@@ -26,7 +26,7 @@
 import tests.helpers.pipelines as tpipes
 from pytorch_lightning import Callback, seed_everything, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen

@@ -265,7 +265,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,

     num_batches = 128 / batch_size
     for dl in (train_dl, val_dl, test_dl):
-        if has_len(dl):
+        if has_len_all_ranks(dl, trainer.training_type_plugin, model):
             assert len(dl) == num_batches
         else:
             assert sum(1 for _ in dl) == num_batches

@@ -855,10 +855,10 @@ def __len__(self):
             return len(original_dataset)

     # with __len__ defined
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
     dataloader = DataLoader(IterableWithLen(), batch_size=16)
-    assert has_len(dataloader)
+    assert has_len_all_ranks(dataloader, trainer.training_type_plugin, model)
     assert has_iterable_dataset(dataloader)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
     with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."):
         trainer.validate(model, val_dataloaders=[dataloader])
     with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."):

@@ -869,10 +869,10 @@ def __len__(self):
         trainer.predict(model, dataloaders=[dataloader])

     # without __len__ defined
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
     dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
-    assert not has_len(dataloader)
+    assert not has_len_all_ranks(dataloader, trainer.training_type_plugin, model)
     assert has_iterable_dataset(dataloader)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
     trainer.validate(model, val_dataloaders=dataloader)
     trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
     trainer.test(model, test_dataloaders=dataloader)

tests/utilities/test_data.py

Lines changed: 21 additions & 2 deletions

@@ -2,8 +2,17 @@
 import torch
 from torch.utils.data.dataloader import DataLoader

-from pytorch_lightning.utilities.data import extract_batch_size, get_len, has_iterable_dataset, has_len, warning_cache
-from tests.helpers.boring_model import RandomDataset, RandomIterableDataset
+from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.data import (
+    extract_batch_size,
+    get_len,
+    has_iterable_dataset,
+    has_len,
+    has_len_all_ranks,
+    warning_cache,
+)
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset


 def test_extract_batch_size():

@@ -73,3 +82,13 @@ def test_get_len():

     assert isinstance(value, float)
     assert value == float("inf")
+
+
+def test_has_len_all_rank():
+    trainer = Trainer(fast_dev_run=True)
+    model = BoringModel()
+
+    with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
+        assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model)
+
+    assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)
