|
20 | 20 | from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
|
21 | 21 |
|
22 | 22 | from pytorch_lightning import Trainer
|
| 23 | +from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin |
23 | 24 | from pytorch_lightning.trainer.states import RunningStage
|
24 | 25 | from pytorch_lightning.trainer.supporters import CombinedLoader
|
25 | 26 | from pytorch_lightning.utilities.enums import DistributedType
|
26 | 27 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
27 |
| -from tests.helpers import BoringModel, RandomDataset |
| 28 | +from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset |
28 | 29 | from tests.helpers.runif import RunIf
|
29 | 30 |
|
30 | 31 |
|
@@ -389,3 +390,16 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
|
389 | 390 | trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
|
390 | 391 | with pytest.warns(UserWarning, match="recommended .* turn this off for val/test/predict"):
|
391 | 392 | trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
|
| 393 | + |
| 394 | + |
@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
    """Test that DataLoader kwargs are not replaced when using Iterable Dataset.

    For iterable-style datasets no sampler/batch_sampler may be injected, so the
    re-instantiation kwargs must carry ``None`` for both and pass the remaining
    DataLoader attributes through untouched, in every running stage.
    """
    dataset = RandomIterableDataset(7, 100)
    dataloader = DataLoader(dataset, batch_size=32)
    dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode)
    # Iterable datasets are incompatible with samplers, so neither may be set.
    assert dl_kwargs["sampler"] is None
    assert dl_kwargs["batch_sampler"] is None
    # Compare the batch size by value: `is` on an int only works for CPython's
    # small-integer cache and is implementation-dependent.
    assert dl_kwargs["batch_size"] == dataloader.batch_size
    # Object identity is intended here: the very same dataset/collate_fn objects
    # must be passed through, not copies.
    assert dl_kwargs["dataset"] is dataloader.dataset
    assert dl_kwargs["collate_fn"] is dataloader.collate_fn
0 commit comments