Custom PyTorch BatchSampler does not work well with PyTorch Lightning #20326

@dadwadw233


Bug description

Where is the bug ❓

When I use a custom BatchSampler to initialize a DataLoader and then use that DataLoader through a PyTorch Lightning DataModule, the shuffle setting does not take effect. No matter which sampler (random or sequential) I use to build the BatchSampler, PyTorch Lightning configures the DistributedSampler it wraps around the data with its default: shuffling is always enabled for the training stage, and for the other stages it is inferred from the type of the DataLoader's `sampler` attribute.

Analysis

The problem is in the `_is_dataloader_shuffled` function (in `pytorch_lightning.utilities.data`), which infers the shuffle state from the DataLoader's `sampler` attribute. That looks reasonable, but when a `batch_sampler` is passed, PyTorch does not look at the sampler inside it and instead sets `DataLoader.sampler` to a default `SequentialSampler`. Lightning therefore always sees a sequential sampler, so shuffling does not behave as expected.

I think the PyTorch behavior is also questionable. In current PyTorch, the DataLoader's `batch_sampler` option is mutually exclusive with `sampler`, `shuffle`, `batch_size` and `drop_last`, so when I pass a custom BatchSampler, PyTorch fills `DataLoader.sampler` with a default `SequentialSampler`. This is counter-intuitive, but it does not produce wrong results on its own: when a batch sampler exists, PyTorch loads data through it, and `DataLoader.sampler` is only iterated directly when automatic batching is disabled (`batch_size=None`).

Key code: [screenshot not available]
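The PyTorch behavior described above can be checked with a short script. This is a minimal sketch that uses the stock `BatchSampler` on a plain list instead of the custom sampler from this issue; the generator seed is only there to make the run repeatable.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler

data = list(range(10))
g = torch.Generator().manual_seed(0)  # seeded only to make the run repeatable
batch_sampler = BatchSampler(RandomSampler(data, generator=g), batch_size=5, drop_last=False)
dl = DataLoader(data, batch_sampler=batch_sampler)

print(type(dl.sampler).__name__)                # SequentialSampler, filled in by PyTorch
print(type(dl.batch_sampler.sampler).__name__)  # RandomSampler, which really drives the order

# Loading goes through the batch sampler, so the order follows the RandomSampler
# and every index is still seen exactly once; dl.sampler is not consulted here.
seen = [int(x) for batch in dl for x in batch]
print(sorted(seen) == data)  # True

# `sampler` and `batch_sampler` cannot be combined in one DataLoader.
try:
    DataLoader(data, sampler=SequentialSampler(data), batch_sampler=batch_sampler)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```

This is exactly the state `_is_dataloader_shuffled` inspects: it reads `dl.sampler` and therefore never sees the `RandomSampler`.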

Suggestions

Since this PyTorch behavior only causes the problem when PyTorch Lightning is involved, I suggest changing the PyTorch Lightning code:

before:

from torch.utils.data import IterableDataset, RandomSampler, SequentialSampler


def _is_dataloader_shuffled(dataloader: object) -> bool:
    if hasattr(dataloader, "__pl_saved_kwargs"):
        # this attribute is not part of PyTorch's DataLoader, but could have been set by
        # our `_replace_init_method` context manager
        if "shuffle" in dataloader.__pl_saved_kwargs:
            return dataloader.__pl_saved_kwargs["shuffle"]
        if "shuffle" in dataloader.__pl_saved_arg_names:
            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
    if hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset):
        # shuffling is useless with iterable datasets
        return False
    if not hasattr(dataloader, "sampler"):
        # shuffling is enabled via a sampler. No sampler, no shuffling
        return False

    # only `dataloader.sampler` is inspected; with a batch_sampler this is
    # PyTorch's default SequentialSampler, not the sampler actually in use
    sampler = dataloader.sampler
    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)

after:

from torch.utils.data import IterableDataset, RandomSampler, SequentialSampler


def _is_dataloader_shuffled(dataloader: object) -> bool:
    if hasattr(dataloader, "__pl_saved_kwargs"):
        # this attribute is not part of PyTorch's DataLoader, but could have been set by
        # our `_replace_init_method` context manager
        if "shuffle" in dataloader.__pl_saved_kwargs:
            return dataloader.__pl_saved_kwargs["shuffle"]
        if "shuffle" in dataloader.__pl_saved_arg_names:
            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
    if hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset):
        # shuffling is useless with iterable datasets
        return False
    if not hasattr(dataloader, "sampler"):
        # shuffling is enabled via a sampler. No sampler, no shuffling
        return False

    # With a batch_sampler, PyTorch sets `dataloader.sampler` to a default
    # SequentialSampler, so inspect the sampler wrapped by the batch sampler.
    # Custom batch samplers may not expose `.sampler`, hence the hasattr check.
    batch_sampler = getattr(dataloader, "batch_sampler", None)
    if batch_sampler is not None and hasattr(batch_sampler, "sampler"):
        sampler = batch_sampler.sampler
    else:
        sampler = dataloader.sampler

    sampler_cls = type(sampler)
    if sampler_cls not in (RandomSampler, SequentialSampler):
        # custom sampler case: a `generator` attribute suggests a random sampler;
        # otherwise we don't know, so assume no shuffling
        return hasattr(sampler, "generator")

    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)
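To sanity-check the sampler lookup in the proposed version without running Lightning, the same resolution step can be pulled into a standalone helper. `effective_sampler` is a hypothetical name used only for this sketch; it mirrors the batch-sampler branch above and is not part of Lightning or PyTorch.

```python
from torch.utils.data import BatchSampler, DataLoader, RandomSampler


def effective_sampler(dataloader):
    """Hypothetical helper mirroring the lookup in the proposed fix.

    Prefer the sampler wrapped by the batch sampler; fall back to
    `dataloader.sampler` when there is no batch sampler (batch_size=None)
    or the batch sampler does not expose a `.sampler` attribute.
    """
    batch_sampler = getattr(dataloader, "batch_sampler", None)
    if batch_sampler is not None and hasattr(batch_sampler, "sampler"):
        return batch_sampler.sampler
    return dataloader.sampler


data = list(range(8))

# Custom batch sampler wrapping a RandomSampler: the case from this issue.
custom = DataLoader(data, batch_sampler=BatchSampler(RandomSampler(data), batch_size=2, drop_last=False))
# Ordinary shuffled loader: PyTorch builds BatchSampler(RandomSampler) itself.
plain = DataLoader(data, batch_size=2, shuffle=True)
# Automatic batching disabled: there is no batch sampler at all.
unbatched = DataLoader(data, batch_size=None)

for name, dl in [("custom", custom), ("plain", plain), ("unbatched", unbatched)]:
    print(name, type(dl.sampler).__name__, type(effective_sampler(dl)).__name__)
# custom SequentialSampler RandomSampler
# plain RandomSampler RandomSampler
# unbatched SequentialSampler SequentialSampler
```

Only the `custom` loader changes outcome: the current code sees `SequentialSampler`, while the proposed lookup finds the `RandomSampler` that actually orders the batches. The other two cases behave as before.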

What version are you seeing the problem on?

master

How to reproduce the bug

First, define a custom BatchSampler such as the one below (the default `BatchSampler` also reproduces the problem):

import math
import random

from torch.utils.data import BatchSampler, Sampler


class DynamicBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool, dataset):
        super().__init__(sampler, batch_size, drop_last)
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.dataset = dataset

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            # draw a random length for the first item of each batch if the dataset supports it
            if not batch and getattr(self.dataset, 'dynamic_length', False):
                min_len, max_len = self.dataset.min_length, self.dataset.max_length
                length = random.randint(min_len, max_len)
            else:
                length = None
            batch.append((idx, length))
            if len(batch) == self.batch_size:
                print(batch)  # debug output: shows the index order actually used
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)

Second, initialize the DataLoader with the BatchSampler:

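from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


def build_dataloader(dataset, cfg):  # hypothetical name: the enclosing function was not shown
    # choose the index order the batch sampler should follow
    if cfg.shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    bsampler = DynamicBatchSampler(sampler, cfg.batch_size, cfg.drop_last, dataset)
    dl = DataLoader(dataset, batch_sampler=bsampler, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)

    return dl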

Using this `dl` in a LightningDataModule then triggers the bug once Lightning wraps its sampler in a DistributedSampler (for example, under DDP).
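Until this is fixed, one possible workaround (a sketch, not verified against every Lightning version) is to turn off Lightning's automatic sampler replacement with `Trainer(use_distributed_sampler=False)` and shard the data yourself with a `DistributedSampler` whose `shuffle` flag is set explicitly. `build_sharded_batch_sampler` is a hypothetical helper, and the stock `BatchSampler` stands in for `DynamicBatchSampler` to keep the example self-contained. With `shuffle=True`, `DistributedSampler.set_epoch(epoch)` has to be called each epoch to change the order; whether Lightning does this for a sampler nested inside a batch sampler is an assumption not checked here.

```python
from torch.utils.data import BatchSampler, DistributedSampler


def build_sharded_batch_sampler(dataset, batch_size, shuffle, num_replicas, rank, seed=0):
    """Hypothetical helper: shard indices per rank with an explicit shuffle flag.

    Meant for Trainer(use_distributed_sampler=False), where Lightning no longer
    replaces the sampler, so the DataLoader must shard the data itself.
    """
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_replicas,  # passed explicitly, so no process group is needed here
        rank=rank,
        shuffle=shuffle,  # used as given instead of being inferred
        seed=seed,
    )
    # DynamicBatchSampler(sampler, batch_size, drop_last, dataset) could be used here instead.
    return BatchSampler(sampler, batch_size=batch_size, drop_last=False)


data = list(range(10))
bs = build_sharded_batch_sampler(data, batch_size=2, shuffle=False, num_replicas=2, rank=0)
print(list(bs))  # [[0, 2], [4, 6], [8]]
```

In a real DDP run, `num_replicas` and `rank` would come from the launched process group rather than being hard-coded.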

cc @tchaton
