
Commit b810098

awaelchli authored and lantiga committed
Fix overlapping samples in DDP when no global seed is set (#17713)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 53815e6)
1 parent 64d84cc commit b810098
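For orientation, the commit targets roughly the following user setup: a shuffled `DataLoader` trained with DDP while no global seed has been set (no `seed_everything` call, so `PL_GLOBAL_SEED` is unset on every rank). The sketch below is illustrative only and is not code from the commit; `TinyModel` and the toy integer dataset are made up for the example.

```python
# Illustrative sketch only (not part of this commit): the kind of setup affected by the
# bug -- a shuffled DataLoader, DDP across two processes, and no global seed.
import torch
from torch.utils.data import DataLoader
from lightning.pytorch import LightningModule, Trainer


class TinyModel(LightningModule):
    """Minimal module that consumes integer "samples" from the dataloader."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch.float().view(-1, 1)).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # Note: no seed_everything() call anywhere, so PL_GLOBAL_SEED stays unset.
    train_loader = DataLoader(range(32), shuffle=True, batch_size=4)
    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp", max_epochs=1, logger=False)
    trainer.fit(TinyModel(), train_loader)
    # Before this fix, the two ranks could draw overlapping samples within the epoch;
    # with the fix, the plain RandomSampler is replaced by a regular DistributedSampler.
```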

3 files changed: +43 −3 lines changed


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))
 
 
+- Fixed an edge case causing overlapping samples in DDP when no global seed is set ([#17713](https://github.com/Lightning-AI/lightning/pull/17713))
+
+
 ## [2.0.2] - 2023-04-24
 
 ### Fixed

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 6 additions & 3 deletions
@@ -16,7 +16,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Iterable, Optional, Tuple, Union
 
-from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 import lightning.pytorch as pl
@@ -245,8 +245,11 @@ def _get_distributed_sampler(
     """This function is used to created the distributed sampler injected within the user DataLoader."""
     kwargs["shuffle"] = shuffle and not overfit_batches
     kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
-    cls = UnrepeatedDistributedSamplerWrapper if mode == RunningStage.PREDICTING else DistributedSamplerWrapper
-    return cls(dataloader.sampler, **kwargs)
+    if mode == RunningStage.PREDICTING:
+        return UnrepeatedDistributedSamplerWrapper(dataloader.sampler, **kwargs)
+    if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
+        return DistributedSampler(dataloader.dataset, **kwargs)
+    return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
 
 
 def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:

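Why the `RandomSampler`/`SequentialSampler` special case matters: roughly speaking, `DistributedSamplerWrapper` materializes the wrapped sampler's indices on each rank and then shards them, so with an unseeded `RandomSampler` every rank shards a different permutation and the shards can overlap; a plain `DistributedSampler` over the dataset shuffles one permutation from a seed shared by all ranks and partitions it disjointly. Below is a toy, self-contained simulation of the two behaviours; it is illustrative only and does not use the actual Lightning or PyTorch sampler classes.

```python
# Toy simulation (not Lightning code) of why sharding per-rank permutations can overlap,
# while sharding one shared, seeded permutation cannot.
import random

NUM_SAMPLES, WORLD_SIZE = 8, 2


def shard_of_private_permutation(rank):
    # Mimics wrapping an unseeded RandomSampler: each rank shuffles with its own,
    # unsynchronized RNG state and keeps every WORLD_SIZE-th index of *its own* order.
    perm = list(range(NUM_SAMPLES))
    random.Random().shuffle(perm)  # a different permutation on every rank
    return perm[rank::WORLD_SIZE]


def shard_of_shared_permutation(rank, seed=0):
    # Mimics DistributedSampler over the dataset: all ranks shuffle with the same seed,
    # agree on one permutation, and take disjoint slices of it.
    perm = list(range(NUM_SAMPLES))
    random.Random(seed).shuffle(perm)  # the same permutation on every rank
    return perm[rank::WORLD_SIZE]


private = [shard_of_private_permutation(r) for r in range(WORLD_SIZE)]
shared = [shard_of_shared_permutation(r) for r in range(WORLD_SIZE)]
print("overlap with private permutations (usually non-empty):", set(private[0]) & set(private[1]))
print("overlap with a shared permutation (always empty):", set(shared[0]) & set(shared[1]))
```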
tests/tests_pytorch/trainer/test_dataloaders.py

Lines changed: 34 additions & 0 deletions
@@ -807,6 +807,40 @@ def test_dataloader_distributed_sampler(tmpdir):
     trainer.test(model)
 
 
+class TestModelUniqueDDPSampling(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.seen_samples = []
+
+    def training_step(self, batch):
+        self.seen_samples.extend(batch.tolist())
+
+    def on_train_end(self):
+        seen_samples = self.all_gather(self.seen_samples)
+        # The samples should be unique across all processes
+        assert set(torch.cat(seen_samples).view(-1).tolist()) == set(range(32))
+
+
+@RunIf(standalone=True)
+def test_distributed_sampler_without_global_seed(tmpdir):
+    """Test that the samples are non-overlapping in DDP when shuffling is enabled and no global seed is set."""
+    # This test must run without a global seed set (e.g. through `seed_everything`), to ensure that each process
+    # starts with a different initial state.
+    assert "PL_GLOBAL_SEED" not in os.environ
+    train_dataloader = DataLoader(range(32), shuffle=True, batch_size=4)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=False,
+        logger=False,
+        enable_progress_bar=False,
+        accelerator="cpu",
+        devices=2,
+        strategy="ddp",
+        max_epochs=1,
+    )
+    trainer.fit(TestModelUniqueDDPSampling(), train_dataloader)
+
+
 class ModelWithDataLoaderDistributedSampler(BoringModel):
     def train_dataloader(self):
         dataloader = super().train_dataloader()

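The new test asserts that `PL_GLOBAL_SEED` is absent because `seed_everything` exports that variable, which is exactly what the `kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))` fallback in the data connector reads. A minimal usage sketch follows (assuming the standard `seed_everything` behaviour; not part of this commit):

```python
# Minimal sketch: seed_everything exports PL_GLOBAL_SEED, so processes launched from this
# one build their distributed sampler with the same seed (42 here).
import os

from lightning.pytorch import seed_everything

seed_everything(42)
print(os.environ["PL_GLOBAL_SEED"])  # -> "42"
```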