Labels: bug, data handling, distributed, ver: 2.5.x
Description
Bug description
When using the PyTorch Lightning Trainer in DDP with use_distributed_sampler=True and providing a custom sampler to the dataloader, the sampling order from the custom sampler is ignored: the data is always shuffled.
This can be tracked down to the line where the custom sampler is wrapped in a DistributedSamplerWrapper. However, the kwargs passed to the wrapper contain shuffle=True due to this line.
When using a custom trainer in combination with Fabric, as suggested here, the sampling order from the custom sampler is respected. The reason can be found here: Fabric does not pass shuffle=True as a keyword argument. A minimal sketch of the difference is shown below.
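The following sketch illustrates the difference. It assumes the internal import path lightning.pytorch.overrides.distributed.DistributedSamplerWrapper (the wrapper the Trainer applies), and passes num_replicas and rank explicitly so the wrapper can be inspected without initializing a process group:
from torch.utils.data import Sampler
from lightning.pytorch.overrides.distributed import DistributedSamplerWrapper


class InOrderSampler(Sampler):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        yield from range(self.n)

    def __len__(self):
        return self.n


sampler = InOrderSampler(10)

# What the Trainer effectively does: shuffle=True permutes the positions
# of the custom sampler's sequence, so its order is lost.
shuffled = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0, shuffle=True)
print(list(shuffled))  # a random selection/permutation of 5 of the 10 indices

# What Fabric effectively does: without shuffle=True, the wrapper only
# shards the sequence across ranks and keeps the custom order.
ordered = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0, shuffle=False)
print(list(ordered))  # [0, 2, 4, 6, 8]
In the shuffled case the printed indices come back in random order; with shuffle=False the wrapper merely strides the in-order sequence across the two ranks.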
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
from torch.utils.data import Dataset, Sampler, DataLoader
import lightning as L
import torch.nn as nn
import torch


class IntegerDataset(Dataset):
    """Dataset of the integers 0..99, so the iteration order is easy to inspect."""

    def __init__(self):
        self.data = list(range(100))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"data": torch.tensor([self.data[idx]])}


class InOrderSampler(Sampler):
    """Custom sampler that yields indices strictly in ascending order."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        yield from range(len(self.dataset))

    def __len__(self):
        return len(self.dataset)


dataset = IntegerDataset()
sampler = InOrderSampler(dataset)
dataloader = DataLoader(dataset=dataset, batch_size=3, sampler=sampler)


class MyModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    def training_step(self, batch, batch_idx):
        # Print the batch to observe the sampling order.
        print(batch)
        x = torch.randn(10, 10, device=self.layer.weight.device)
        output = self.layer(x)
        loss = nn.functional.mse_loss(output, x)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


model = MyModule()
# Run on a machine with multiple devices so the Trainer auto-selects DDP;
# only then does use_distributed_sampler wrap the custom sampler.
trainer = L.Trainer(
    max_epochs=1,
    use_distributed_sampler=True,
)
trainer.fit(
    model=model,
    train_dataloaders=dataloader,
)
Error messages and logs
The integers from the dataset are not printed in ascending order; they are printed in a random, shuffled order.
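A possible workaround, shown here as an untested sketch (again assuming the internal lightning.pytorch.overrides.distributed.DistributedSamplerWrapper import, and with the hypothetical class name MyModuleWithManualSampler), is to disable use_distributed_sampler and wrap the custom sampler manually with shuffle=False, mirroring what Fabric does. Building the dataloader inside the train_dataloader hook ensures the process group already exists, so the wrapper can infer num_replicas and rank:
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.overrides.distributed import DistributedSamplerWrapper


class MyModuleWithManualSampler(MyModule):  # MyModule, InOrderSampler, dataset from the repro above
    def train_dataloader(self):
        # This hook runs after DDP init, so DistributedSamplerWrapper can
        # read num_replicas and rank from the process group. Note it raises
        # if run outside a distributed context.
        wrapped = DistributedSamplerWrapper(InOrderSampler(dataset), shuffle=False)
        return DataLoader(dataset, batch_size=3, sampler=wrapped)


trainer = L.Trainer(max_epochs=1, use_distributed_sampler=False)
trainer.fit(model=MyModuleWithManualSampler())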
Environment
pytorch-lightning: 2.5.1.post0
torch: 2.7.1
torchmetrics: 1.7.0
torchvision: 0.22.1
More info
No response