
Lightning+DDP: use_distributed_sampler=True always shuffles data in DDP despite using custom sampler #21131

@dmholtz

Bug description

When using the PyTorch Lightning Trainer with DDP, use_distributed_sampler=True, and a custom sampler passed to the dataloader, the sampling order of the custom sampler is ignored: the data is always shuffled.

The issue can be traced to this line, where the custom sampler is wrapped in a DistributedSamplerWrapper. However, the kwargs contain shuffle=True because of this line.

When using a custom trainer in combination with Fabric, as suggested here, the sampling order of the custom sampler is respected. The reason can be found here: Fabric does not pass shuffle=True as a keyword argument.
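The effect can be demonstrated in isolation by wrapping a sequential sampler the same way. A minimal sketch, assuming DistributedSamplerWrapper is importable from lightning.pytorch.overrides.distributed (the internal path may differ between versions); num_replicas and rank are passed explicitly so the snippet runs without an initialized process group:

from torch.utils.data import SequentialSampler
from lightning.pytorch.overrides.distributed import DistributedSamplerWrapper

sampler = SequentialSampler(range(12))  # yields 0, 1, ..., 11 in order

# What the Trainer does today: shuffle=True, so the wrapped sampler's order is discarded.
shuffled = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0, shuffle=True)
print(list(shuffled))  # a shuffled selection of rank 0's shard; the sequential order is lost

# What Fabric does: no shuffle=True, so the order is preserved within each rank's shard.
ordered = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0, shuffle=False)
print(list(ordered))  # [0, 2, 4, 6, 8, 10]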

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

from torch.utils.data import Dataset, Sampler, DataLoader
import lightning as L
import torch.nn as nn
import torch

class IntegerDataset(Dataset):
    def __init__(self):
        self.data = [i for i in range(100)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"data": torch.tensor([self.data[idx]])}

class InOrderSampler(Sampler):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        yield from range(len(self.dataset))

    def __len__(self):
        return len(self.dataset)
    
dataset = IntegerDataset()
sampler = InOrderSampler(dataset)
dataloader = DataLoader(dataset=dataset, batch_size=3, sampler=sampler)

class MyModule(L.LightningModule):

    def __init__(self):
        super().__init__()

        self.layer = nn.Linear(10, 10)

    def training_step(self, batch, batch_idx):
        print(batch)  # print each batch so the sampling order can be inspected

        x = torch.randn(10, 10, device=self.layer.weight.device)  # dummy input; the loss itself is irrelevant here
        output = self.layer(x)
        loss = nn.functional.mse_loss(output, x)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = MyModule()

trainer = L.Trainer(
    max_epochs=1,
    use_distributed_sampler=True,
)
trainer.fit(
    model=model,
    train_dataloaders=dataloader,
)

Error messages and logs

The integers from the dataset are not printed in order, but are randomly shuffled.
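Until this is fixed, one possible workaround is to set use_distributed_sampler=False and let the custom sampler shard the indices itself, so the Trainer leaves it untouched. A minimal sketch, reusing IntegerDataset and the imports from the reproduction above; the rank/world-size handling is illustrative, not part of the original report:

import torch.distributed as dist

class ShardedInOrderSampler(Sampler):
    """In-order sampler that yields only this rank's slice of the indices."""

    def __init__(self, dataset):
        self.dataset = dataset

    @staticmethod
    def _rank_and_world():
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
        return 0, 1  # single-process fallback

    def __iter__(self):
        rank, world_size = self._rank_and_world()
        # every world_size-th index, starting at this rank, in ascending order
        yield from range(rank, len(self.dataset), world_size)

    def __len__(self):
        rank, world_size = self._rank_and_world()
        return len(range(rank, len(self.dataset), world_size))

dataloader = DataLoader(dataset=dataset, batch_size=3, sampler=ShardedInOrderSampler(dataset))
trainer = L.Trainer(max_epochs=1, use_distributed_sampler=False)  # do not wrap the sampler

With 100 samples and 2 ranks, each process then sees 50 indices, still in ascending order.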

Environment

Current environment
pytorch-lightning         2.5.1.post0
torch                     2.7.1
torchmetrics              1.7.0
torchvision               0.22.1

More info

No response

cc @justusschock @tchaton

Labels

bug (Something isn't working), data handling (Generic data-related topic), distributed (Generic distributed-related topic), ver: 2.5.x
