Description
Bug description
where is the bug❓
When I build a DataLoader from a custom `BatchSampler` and serve it through a PyTorch Lightning `LightningDataModule`, the shuffle setting is not respected. Whichever sampler I pass to the `BatchSampler` (`RandomSampler` or `SequentialSampler`), Lightning configures the `DistributedSampler` it wraps around the data with its default choice: shuffling is always enabled for the training stage, and for the other stages it is decided by the type of the dataloader's `sampler` attribute.
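The sketch below approximates what this means for a non-training stage in a run with 2 processes. It does not reproduce Lightning's actual code path: it only derives the `shuffle` flag from `dl.sampler` and then replaces the user's `RandomSampler` with a `DistributedSampler` built with that flag. `num_replicas=2` and `rank=0` are assumptions, passed explicitly so that no process group is needed.

```python
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler

dataset = list(range(10))  # a plain list is enough as a map-style dataset

# What the user asked for: random order, delivered through a batch sampler.
user_batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=5, drop_last=False)
dl = DataLoader(dataset, batch_sampler=user_batch_sampler)

# The flag is derived from `dl.sampler`, which PyTorch set to a default SequentialSampler.
shuffle = isinstance(dl.sampler, RandomSampler)
print(shuffle)  # False

# Approximation of the wrapped sampler on rank 0 of 2: the user's RandomSampler is gone.
dist_sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle)
rebuilt = BatchSampler(dist_sampler, batch_size=5, drop_last=False)
print(list(rebuilt))  # [[0, 2, 4, 6, 8]]
```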
Analysis
The problem lies in the `_is_dataloader_shuffled` function (in `pytorch_lightning.utilities.data`). It decides the shuffle state from the dataloader's `sampler` attribute. That looks reasonable, but when a `batch_sampler` is given, PyTorch does not take `dataloader.sampler` from it; it fills the attribute with a default `SequentialSampler` instead. Lightning therefore always sees a sequential sampler, and shuffling does not behave as I expected.
Arguably the PyTorch behavior is also questionable. In current PyTorch, the `batch_sampler` option is mutually exclusive with `sampler`, `shuffle`, `batch_size`, and `drop_last`, so when I pass a custom `batch_sampler`, the DataLoader only creates a default `SequentialSampler` for its `sampler` attribute. This is a bit counter-intuitive, but it does not produce wrong results in plain PyTorch: when a batch sampler exists, indices are drawn from it, and `sampler` is only iterated directly when automatic batching is disabled (`batch_size=None`).
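The PyTorch behavior described above can be checked directly. This minimal sketch (a plain list stands in for a map-style dataset) shows that `sampler` and `batch_sampler` cannot be passed together, that `dl.sampler` falls back to a `SequentialSampler` when only `batch_sampler` is given, and that `sampler` drives iteration by itself only when automatic batching is disabled with `batch_size=None`.

```python
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler

dataset = list(range(6))
random_sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(random_sampler, batch_size=3, drop_last=False)

# 1) `sampler` and `batch_sampler` are mutually exclusive.
try:
    DataLoader(dataset, sampler=random_sampler, batch_sampler=batch_sampler)
    raised = False
except ValueError:
    raised = True
print("mutually exclusive:", raised)  # mutually exclusive: True

# 2) With only `batch_sampler`, `dl.sampler` is a default SequentialSampler,
#    while indices actually come from the batch sampler's RandomSampler.
dl = DataLoader(dataset, batch_sampler=batch_sampler)
print(type(dl.sampler).__name__)                # SequentialSampler
print(type(dl.batch_sampler.sampler).__name__)  # RandomSampler

# 3) `sampler` is iterated directly only when automatic batching is disabled.
dl_unbatched = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=None)
print(dl_unbatched.batch_sampler)        # None
print([int(x) for x in dl_unbatched])    # [0, 1, 2, 3, 4, 5]
```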
Suggestions
Because this PyTorch behavior only causes trouble when PyTorch Lightning is involved, I suggest changing the Lightning code instead:
Before (current implementation):
```python
from torch.utils.data import IterableDataset, RandomSampler, SequentialSampler


def _is_dataloader_shuffled(dataloader: object) -> bool:
    if hasattr(dataloader, "__pl_saved_kwargs"):
        # this attribute is not part of PyTorch's DataLoader, but could have been set by
        # our `_replace_init_method` context manager
        if "shuffle" in dataloader.__pl_saved_kwargs:
            return dataloader.__pl_saved_kwargs["shuffle"]
        if "shuffle" in dataloader.__pl_saved_arg_names:
            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
    if hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset):
        # shuffling is useless with iterable datasets
        return False
    if not hasattr(dataloader, "sampler"):
        # shuffling is enabled via a sampler. No sampler, no shuffling
        return False
    sampler = dataloader.sampler
    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)
```
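The consequence can be shown with a stripped-down copy of the final sampler check in the current implementation. The `__pl_saved_kwargs` and `IterableDataset` branches are omitted, and `old_sampler_check` is a hypothetical name, not part of Lightning. A loader whose batch sampler wraps a `RandomSampler` is reported as not shuffled.

```python
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler


def old_sampler_check(dataloader: object) -> bool:
    # Mirrors the tail of the current `_is_dataloader_shuffled`: only `dataloader.sampler` is inspected.
    if not hasattr(dataloader, "sampler"):
        return False
    sampler = dataloader.sampler
    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)


dataset = list(range(8))
plain = DataLoader(dataset, batch_size=4, shuffle=True)
via_batch_sampler = DataLoader(
    dataset, batch_sampler=BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
)
print(old_sampler_check(plain))              # True
print(old_sampler_check(via_batch_sampler))  # False, although the order is random
```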
After (proposed):
```python
from torch.utils.data import IterableDataset, RandomSampler, SequentialSampler


def _is_dataloader_shuffled(dataloader: object) -> bool:
    if hasattr(dataloader, "__pl_saved_kwargs"):
        # this attribute is not part of PyTorch's DataLoader, but could have been set by
        # our `_replace_init_method` context manager
        if "shuffle" in dataloader.__pl_saved_kwargs:
            return dataloader.__pl_saved_kwargs["shuffle"]
        if "shuffle" in dataloader.__pl_saved_arg_names:
            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
    if hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset):
        # shuffling is useless with iterable datasets
        return False
    if not hasattr(dataloader, "sampler"):
        # shuffling is enabled via a sampler. No sampler, no shuffling
        return False
    batch_sampler = getattr(dataloader, "batch_sampler", None)
    if batch_sampler is not None and hasattr(batch_sampler, "sampler"):
        # with a batch sampler, the wrapped sampler decides the index order;
        # `dataloader.sampler` is then only PyTorch's default SequentialSampler
        sampler = batch_sampler.sampler
    else:
        sampler = dataloader.sampler
    sampler_cls = type(sampler)
    if sampler_cls not in (RandomSampler, SequentialSampler):
        # custom sampler case:
        if hasattr(sampler, "generator"):
            # maybe custom random sampler
            return True
        else:
            # we don't know
            return False
    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)
```
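A stripped-down version of the proposed logic can be exercised against a few loaders to confirm that it follows the sampler wrapped by the batch sampler. The `__pl_saved_kwargs` and `IterableDataset` branches are again omitted; `new_sampler_check` and `MyRandomSampler` are hypothetical names used only for illustration.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler


def new_sampler_check(dataloader: object) -> bool:
    # Mirrors the tail of the proposed `_is_dataloader_shuffled`.
    if not hasattr(dataloader, "sampler"):
        return False
    batch_sampler = getattr(dataloader, "batch_sampler", None)
    if batch_sampler is not None and hasattr(batch_sampler, "sampler"):
        sampler = batch_sampler.sampler  # the sampler that actually orders the indices
    else:
        sampler = dataloader.sampler
    if type(sampler) not in (RandomSampler, SequentialSampler):
        # custom sampler: guess "random" if it exposes a `generator`, like RandomSampler does
        return hasattr(sampler, "generator")
    return isinstance(sampler, RandomSampler)


class MyRandomSampler(Sampler):
    # Hypothetical custom sampler that exposes a `generator` attribute.
    def __init__(self, data_source, generator=None):
        self.data_source = data_source
        self.generator = generator

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source), generator=self.generator).tolist())

    def __len__(self):
        return len(self.data_source)


dataset = list(range(8))
loaders = {
    "plain shuffle=True": DataLoader(dataset, batch_size=4, shuffle=True),
    "batch sampler + RandomSampler": DataLoader(
        dataset, batch_sampler=BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
    ),
    "batch sampler + SequentialSampler": DataLoader(
        dataset, batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
    ),
    "batch sampler + custom sampler": DataLoader(
        dataset, batch_sampler=BatchSampler(MyRandomSampler(dataset), batch_size=4, drop_last=False)
    ),
}
results = {name: new_sampler_check(dl) for name, dl in loaders.items()}
for name, shuffled in results.items():
    print(f"{name}: {shuffled}")
```

The `generator` heuristic is the report's own proposal; a custom sequential sampler that happens to carry a `generator` attribute would still be classified as shuffled.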
What version are you seeing the problem on?
master
How to reproduce the bug
First, define a custom `BatchSampler` such as the one below (the default `BatchSampler` reproduces the problem as well):
```python
import math
import random

from torch.utils.data import BatchSampler, Sampler


class DynamicBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool, dataset):
        super().__init__(sampler, batch_size, drop_last)
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.dataset = dataset

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            # the first item of each batch carries a randomly drawn length, the rest carry None
            if not batch and getattr(self.dataset, 'dynamic_length', False):
                min_len, max_len = self.dataset.min_length, self.dataset.max_length
                length = random.randint(min_len, max_len)
            else:
                length = None
            batch.append((idx, length))
            if len(batch) == self.batch_size:
                print(batch)  # debug output: shows the order produced by `self.sampler`
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)
```
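Then initialize the DataLoader with the batch sampler:
```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


def build_dataloader(cfg, dataset):  # hypothetical wrapper; the report shows only the body
    if cfg.shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    bsampler = DynamicBatchSampler(sampler, cfg.batch_size, cfg.drop_last, dataset)
    dl = DataLoader(dataset, batch_sampler=bsampler, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
    return dl
```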
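Using PyTorch's own `BatchSampler` in place of `DynamicBatchSampler` (which the report says is enough to trigger the problem), the sketch below builds the loader for both values of `cfg.shuffle`. `build_dataloader` is a hypothetical helper and `cfg` a hypothetical `SimpleNamespace` holding only the fields the construction reads. `dl.sampler`, the attribute the current check inspects, is a `SequentialSampler` in both cases, so the two configurations cannot be told apart from it.

```python
from types import SimpleNamespace

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler


def build_dataloader(cfg, dataset):
    # Same construction as the report, with torch's BatchSampler standing in for DynamicBatchSampler
    sampler = RandomSampler(dataset) if cfg.shuffle else SequentialSampler(dataset)
    bsampler = BatchSampler(sampler, cfg.batch_size, cfg.drop_last)
    return DataLoader(dataset, batch_sampler=bsampler, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)


dataset = list(range(10))
observed = {}
for shuffle in (True, False):
    # hypothetical config holding only the fields read above
    cfg = SimpleNamespace(shuffle=shuffle, batch_size=4, drop_last=False, num_workers=0, pin_memory=False)
    dl = build_dataloader(cfg, dataset)
    observed[shuffle] = (type(dl.sampler).__name__, type(dl.batch_sampler.sampler).__name__)
    print(shuffle, *observed[shuffle])
# True SequentialSampler RandomSampler
# False SequentialSampler SequentialSampler
```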
If `dl` is then returned from a `LightningDataModule`, the bug occurs.
cc @tchaton