Skip to content

Commit bd5866b

Browse files
authored
fix BatchSampler not working correctly (#20327)
* fix BatchSampler not working correctly * [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci * add batch sampler shuffle state test
1 parent 1f2d7a1 commit bd5866b

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/lightning/pytorch/utilities/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ def _is_dataloader_shuffled(dataloader: object) -> bool:
349349
if not hasattr(dataloader, "sampler"):
350350
# shuffling is enabled via a sampler. No sampler, no shuffling
351351
return False
352-
sampler = dataloader.sampler
352+
batch_sampler = dataloader.batch_sampler
353+
sampler = batch_sampler.sampler if batch_sampler is not None else dataloader.sampler
353354
if isinstance(sampler, SequentialSampler):
354355
return False
355356
return isinstance(sampler, RandomSampler)

tests/tests_pytorch/utilities/test_data.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from lightning.pytorch.trainer.states import RunningStage
1313
from lightning.pytorch.utilities.data import (
1414
_get_dataloader_init_args_and_kwargs,
15+
_is_dataloader_shuffled,
1516
_update_dataloader,
1617
extract_batch_size,
1718
has_len_all_ranks,
@@ -20,7 +21,7 @@
2021
from lightning.pytorch.utilities.exceptions import MisconfigurationException
2122
from lightning_utilities.test.warning import no_warning_call
2223
from torch import Tensor
23-
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
24+
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
2425

2526

2627
def test_extract_batch_size():
@@ -304,6 +305,31 @@ def __init__(self, extra_arg):
304305
_ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
305306

306307

308+
def test_batch_sampler_shuffle_setting():
309+
"""Test whether the `shuffle` state is correctly set in the `BatchSampler`."""
310+
311+
random_sampler = RandomSampler(range(10))
312+
seq_sampler = SequentialSampler(range(10))
313+
shuffled_dataloader = DataLoader(
314+
range(10), batch_sampler=BatchSampler(random_sampler, batch_size=2, drop_last=False)
315+
)
316+
sequential_dataloader = DataLoader(
317+
range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=2, drop_last=False)
318+
)
319+
320+
# if batch_size is 1, PyTorch initializes a default SequentialSampler and sets BatchSampler to None
321+
single_dataloader = DataLoader(range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=1, drop_last=False))
322+
assert _is_dataloader_shuffled(shuffled_dataloader)
323+
assert not _is_dataloader_shuffled(sequential_dataloader)
324+
assert not _is_dataloader_shuffled(single_dataloader)
325+
326+
# if batch_size is 1 and no batch_sampler is set, PyTorch will set BatchSampler to None
327+
single_dataloader = DataLoader(range(10), batch_size=1)
328+
shuffled_single_dataloader = DataLoader(range(10), batch_size=1, shuffle=True)
329+
assert not _is_dataloader_shuffled(single_dataloader)
330+
assert _is_dataloader_shuffled(shuffled_single_dataloader)
331+
332+
307333
@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
308334
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
309335
"""Test that DataLoader kwargs are not replaced when using Iterable Dataset."""

0 commit comments

Comments
 (0)