Commit f72bf31

rohitgr7 authored and lexierule committed
Disable attaching samplers when using IterableDataset (#11507)
1 parent e95d8b1 commit f72bf31
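
Context for the change: PyTorch's `DataLoader` raises a `ValueError` when a `sampler` or `batch_sampler` is passed alongside an iterable-style dataset, so Lightning has to skip its automatic sampler replacement for `IterableDataset`. A minimal sketch of that underlying constraint (illustrative only; the `RangeIterable` class is not part of this commit):

```python
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.sampler import SequentialSampler


class RangeIterable(IterableDataset):
    """Toy iterable-style dataset that simply yields 0..9."""

    def __iter__(self):
        return iter(range(10))


ds = RangeIterable()

# Without a sampler, iterable-style datasets batch normally.
print(next(iter(DataLoader(ds, batch_size=4))))  # tensor([0, 1, 2, 3])

# With any sampler, DataLoader rejects the combination at construction time;
# this is what Lightning's sampler replacement used to trigger.
try:
    DataLoader(ds, sampler=SequentialSampler(range(10)))
except ValueError as err:
    print(err)  # "DataLoader with IterableDataset: expected unspecified sampler option ..."
```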

File tree

3 files changed (+25 −4 lines):

- CHANGELOG.md
- pytorch_lightning/trainer/data_loading.py
- tests/trainer/test_data_loading.py

CHANGELOG.md (3 additions, 0 deletions)

```diff
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `LSFEnvironment` to use `LSB_DJOB_RANKFILE` environment variable instead of `LSB_HOSTS` for determining node rank and main address ([#10825](https://github.com/PyTorchLightning/pytorch-lightning/pull/10825))
 
 
+- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
+
+
 ## [1.5.8] - 2022-01-05
 
 ### Fixed
```

pytorch_lightning/trainer/data_loading.py (7 additions, 3 deletions)

```diff
@@ -272,9 +272,13 @@ def _get_dataloader_init_kwargs(
 
         # kwargs to re-construct the dataloader
         dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
-        dl_kwargs.update(
-            TrainerDataLoadingMixin._dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)
-        )
+        if isinstance(dl_kwargs["dataset"], IterableDataset):
+            dl_kwargs["batch_sampler"] = None
+            dl_kwargs["sampler"] = None
+        else:
+            dl_kwargs.update(
+                TrainerDataLoadingMixin._dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)
+            )
 
         required_args = {
             p.name
```
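
The fix clears the sampler keys rather than dropping them because `DataLoader` treats `sampler=None` and `batch_sampler=None` as "unspecified", so a loader re-built from the resolved kwargs no longer hits the error described above. A standalone sketch of that re-construction path (names here are illustrative, not Lightning internals):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamOfOnes(IterableDataset):
    """Illustrative iterable-style dataset yielding eight fixed tensors."""

    def __iter__(self):
        for _ in range(8):
            yield torch.ones(3)


# Mimics re-building a DataLoader from resolved kwargs: the sampler keys
# are present but cleared, which DataLoader accepts as "unspecified".
dl_kwargs = {"dataset": StreamOfOnes(), "batch_size": 4, "sampler": None, "batch_sampler": None}
loader = DataLoader(**dl_kwargs)  # no ValueError
print(next(iter(loader)).shape)  # torch.Size([4, 3])
```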

tests/trainer/test_data_loading.py (15 additions, 1 deletion)

```diff
@@ -20,11 +20,12 @@
 from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.runif import RunIf
 
 
@@ -389,3 +390,16 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
     trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
     with pytest.warns(UserWarning, match="recommended .* turn this off for val/test/predict"):
         trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
+
+
+@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
+def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
+    """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
+    dataset = RandomIterableDataset(7, 100)
+    dataloader = DataLoader(dataset, batch_size=32)
+    dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode)
+    assert dl_kwargs["sampler"] is None
+    assert dl_kwargs["batch_sampler"] is None
+    assert dl_kwargs["batch_size"] is dataloader.batch_size
+    assert dl_kwargs["dataset"] is dataloader.dataset
+    assert dl_kwargs["collate_fn"] is dataloader.collate_fn
```
