|
12 | 12 | from lightning.pytorch.trainer.states import RunningStage
|
13 | 13 | from lightning.pytorch.utilities.data import (
|
14 | 14 | _get_dataloader_init_args_and_kwargs,
|
| 15 | + _is_dataloader_shuffled, |
15 | 16 | _update_dataloader,
|
16 | 17 | extract_batch_size,
|
17 | 18 | has_len_all_ranks,
|
|
20 | 21 | from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
21 | 22 | from lightning_utilities.test.warning import no_warning_call
|
22 | 23 | from torch import Tensor
|
23 |
| -from torch.utils.data import BatchSampler, DataLoader, RandomSampler |
| 24 | +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler |
24 | 25 |
|
25 | 26 |
|
26 | 27 | def test_extract_batch_size():
|
@@ -304,6 +305,31 @@ def __init__(self, extra_arg):
|
304 | 305 | _ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
|
305 | 306 |
|
306 | 307 |
|
def test_batch_sampler_shuffle_setting():
    """Test that the `shuffle` state is correctly inferred from a dataloader's (batch) sampler.

    Covers both the explicit ``batch_sampler=`` construction path and the implicit path where
    PyTorch auto-creates the sampler from ``batch_size``/``shuffle`` arguments.
    """
    random_sampler = RandomSampler(range(10))
    seq_sampler = SequentialSampler(range(10))
    shuffled_dataloader = DataLoader(
        range(10), batch_sampler=BatchSampler(random_sampler, batch_size=2, drop_last=False)
    )
    sequential_dataloader = DataLoader(
        range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=2, drop_last=False)
    )

    # An explicitly passed BatchSampler with batch_size=1 wrapping a SequentialSampler
    # must still be detected as not shuffled.
    single_dataloader = DataLoader(range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=1, drop_last=False))
    assert _is_dataloader_shuffled(shuffled_dataloader)
    assert not _is_dataloader_shuffled(sequential_dataloader)
    assert not _is_dataloader_shuffled(single_dataloader)

    # When no batch_sampler is passed, PyTorch creates a default sampler internally:
    # SequentialSampler by default, RandomSampler when shuffle=True. The shuffle state
    # must then be inferred from that auto-created sampler.
    single_dataloader = DataLoader(range(10), batch_size=1)
    shuffled_single_dataloader = DataLoader(range(10), batch_size=1, shuffle=True)
    assert not _is_dataloader_shuffled(single_dataloader)
    assert _is_dataloader_shuffled(shuffled_single_dataloader)

307 | 333 | @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
|
308 | 334 | def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
|
309 | 335 | """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
|
|
0 commit comments