Open
Labels
bug (Something isn't working), strategy: ddp (DistributedDataParallel), ver: 2.2.x
Description
Bug description
The batches and their order are identical across separate runs of the script when using strategy='ddp' with a dataloader that has shuffle=True.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Say you have a train.py that prints the current input on each training iteration and has shuffling enabled in the dataloader:
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
import lightning.pytorch as pl


class SomeLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.p1 = torch.nn.Parameter(torch.tensor(0.0))
        self.p2 = torch.nn.Parameter(torch.tensor(0.0))

    def training_step(self, batch):
        x, y = batch
        # Print the current input so the batch order is visible.
        print(x.item())
        return F.mse_loss(x * self.p1 + self.p2, y)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
        )
        return {
            "optimizer": optimizer,
        }


lightning_module = SomeLightningModule()
trainer = pl.Trainer(
    strategy='ddp',
    max_epochs=1,
)
train_dataset = TensorDataset(torch.arange(5).float(), torch.arange(5).float())
train_loader = DataLoader(train_dataset, shuffle=True)
trainer.fit(lightning_module, train_dataloaders=train_loader)
When strategy='ddp', the script will print the same numbers across different runs:
$ python3 train.py
4.0
0.0
1.0
3.0
2.0
$ python3 train.py
4.0
0.0
1.0
3.0
2.0
This behavior can be undesirable, since people may want different batch orders across runs (e.g. to construct ensembles or to estimate average performance).
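A possible explanation, offered as an assumption and not confirmed in this report: under DDP, Lightning swaps the dataloader's sampler for `torch.utils.data.distributed.DistributedSampler`, which shuffles with its own generator seeded from a fixed `seed` (default `0`) plus the epoch index, not from the global RNG state. The minimal sketch below uses the sampler directly, without Lightning or a process group. The helper name `epoch_order` and the dataset size of 100 are illustrative only.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Illustrative dataset; the size is arbitrary.
dataset = TensorDataset(torch.arange(100).float())


def epoch_order(seed, epoch=0):
    """Return the index order the sampler yields for a given seed and epoch."""
    # num_replicas and rank are passed explicitly, so no process group is needed.
    sampler = DistributedSampler(
        dataset, num_replicas=1, rank=0, shuffle=True, seed=seed
    )
    sampler.set_epoch(epoch)
    return list(sampler)


first = epoch_order(seed=0)
second = epoch_order(seed=0)

# Same seed and epoch give the same permutation, regardless of global RNG state.
print(first == second)  # True

# A different seed will typically give a different permutation.
print(first == epoch_order(seed=1))
```

If this is indeed the cause, passing a different seed per run (for example by calling `pl.seed_everything` with a run-specific value before creating the `Trainer`) might vary the order. Whether Lightning forwards that seed to the injected sampler should be checked against the installed version.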
Error messages and logs
# Error messages and logs here please
Environment
Current environment
- CUDA:
- GPU:
- Graphics Device
- available: True
- version: 11.8
- GPU:
- Lightning:
- lightning: 2.2.0.post0
- lightning-utilities: 0.10.1
- pytorch-lightning: 1.7.7
- torch: 2.1.2
- torchaudio: 2.1.2
- torchmetrics: 0.10.3
- torchvision: 0.16.2
- Packages:
- absl-py: 1.3.0
- aiohttp: 3.8.3
- aiosignal: 1.3.1
- alphafold-colabfold: 2.3.6
- altair: 5.4.0
- anarci: 1.3
- antiberty: 0.1.3
- antlr4-python3-runtime: 4.9.3
- anyio: 3.5.0
- appdirs: 1.4.4
- argon2-cffi: 21.3.0
- argon2-cffi-bindings: 21.2.0
- asttokens: 2.0.5
- astunparse: 1.6.3
- async-lru: 2.0.4
- async-timeout: 4.0.2
- attrs: 22.1.0
- babel: 2.11.0
- backcall: 0.2.0
- beautifulsoup4: 4.12.2
- biopython: 1.79
- bleach: 4.1.0
- blinker: 1.5
- bottleneck: 1.3.5
- brotlipy: 0.7.0
- cached-property: 1.5.2
- cachetools: 5.2.0
- certifi: 2023.5.7
- cffi: 1.15.1
- charset-normalizer: 2.1.1
- chex: 0.1.86
- click: 8.1.3
- cmake: 3.28.3
- colabfold: 1.5.5
- colorama: 0.4.6
- comm: 0.1.2
- contextlib2: 21.6.0
- contourpy: 1.0.6
- cryptography: 38.0.3
- cycler: 0.11.0
- debugpy: 1.6.7
- decorator: 5.1.1
- deepspeed: 0.9.5
- defusedxml: 0.7.1
- dm-haiku: 0.0.12
- dm-tree: 0.1.8
- docker-pycreds: 0.4.0
- docstring-parser: 0.15
- einops: 0.8.0
- entrypoints: 0.4
- et-xmlfile: 1.1.0
- etils: 1.5.2
- exceptiongroup: 1.0.4
- executing: 0.8.3
- fastjsonschema: 2.16.2
- filelock: 3.13.1
- flatbuffers: 24.3.25
- flax: 0.8.5
- fonttools: 4.38.0
- frozenlist: 1.3.3
- fsspec: 2024.3.1
- gast: 0.6.0
- gdown: 5.1.0
- gemmi: 0.5.7
- gitdb: 4.0.9
- gitpython: 3.1.29
- gmpy2: 2.1.2
- google-auth: 2.14.1
- google-auth-oauthlib: 0.4.6
- google-pasta: 0.2.0
- grpcio: 1.49.1
- h5py: 3.11.0
- hjson: 3.1.0
- huggingface-hub: 0.22.2
- hydra-core: 1.3.2
- idna: 3.4
- immutabledict: 4.2.0
- importlib-metadata: 4.13.0
- importlib-resources: 6.1.2
- ipykernel: 6.25.0
- ipython: 8.15.0
- ipython-genutils: 0.2.0
- ipywidgets: 8.0.4
- jax: 0.3.25
- jaxlib: 0.3.25+cuda11.cudnn82
- jedi: 0.18.1
- jinja2: 3.1.2
- jmp: 0.0.4
- json5: 0.9.6
- jsonargparse: 4.27.5
- jsonschema: 4.17.3
- jupyter: 1.0.0
- jupyter-client: 7.4.9
- jupyter-console: 6.6.3
- jupyter-core: 5.5.0
- jupyter-events: 0.6.3
- jupyter-lsp: 2.2.0
- jupyter-server: 2.10.0
- jupyter-server-terminals: 0.4.4
- jupyterlab: 4.0.8
- jupyterlab-pygments: 0.1.2
- jupyterlab-server: 2.22.0
- jupyterlab-widgets: 3.0.9
- keras: 3.4.1
- kiwisolver: 1.4.4
- libclang: 18.1.1
- lightning: 2.2.0.post0
- lightning-utilities: 0.10.1
- lit: 18.1.1
- markdown: 3.4.1
- markdown-it-py: 3.0.0
- markupsafe: 2.1.1
- matplotlib: 3.6.2
- matplotlib-inline: 0.1.6
- mdurl: 0.1.2
- mistune: 2.0.4
- mkl-fft: 1.3.1
- mkl-random: 1.2.2
- mkl-service: 2.4.0
- ml-collections: 0.1.1
- ml-dtypes: 0.3.2
- mmcif-pdbx: 2.0.1
- mpi4py: 3.1.4
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.2
- munkres: 1.1.4
- namex: 0.0.8
- narwhals: 1.5.0
- nbclient: 0.8.0
- nbconvert: 7.10.0
- nbformat: 5.9.2
- nest-asyncio: 1.5.6
- networkx: 3.1
- ninja: 1.11.1
- notebook: 6.3.0
- notebook-shim: 0.2.3
- numexpr: 2.8.4
- numpy: 1.23.5
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- openpyxl: 3.1.5
- opt-einsum: 3.3.0
- optax: 0.2.2
- optree: 0.11.0
- orbax-checkpoint: 0.5.20
- overrides: 7.4.0
- packaging: 21.3
- pandas: 1.5.3
- pandocfilters: 1.5.0
- parso: 0.8.3
- path: 16.2.0
- pathtools: 0.1.2
- pdb2pqr: 3.6.1
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.2.0
- pip: 22.3.1
- platformdirs: 3.10.0
- ply: 3.11
- pmw: 2.0.1
- pooch: 1.6.0
- prody: 2.2.0
- prometheus-client: 0.14.1
- promise: 2.3
- prompt-toolkit: 3.0.43
- propka: 3.5.1
- protobuf: 4.21.9
- psutil: 5.9.4
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 9.0.0
- py3dmol: 2.0.4
- pyasn1: 0.4.8
- pyasn1-modules: 0.3.0
- pycollada: 0.8
- pycparser: 2.21
- pydantic: 1.10.11
- pydeprecate: 0.3.2
- pygments: 2.15.1
- pyjwt: 2.6.0
- pykerberos: 1.2.4
- pymol: 2.5.5
- pyopenssl: 22.1.0
- pyparsing: 3.0.9
- pyqt5: 5.15.7
- pyqt5-sip: 12.11.0
- pyrsistent: 0.20.0
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- python-json-logger: 2.0.7
- pytorch-lightning: 1.7.7
- pytz: 2022.7
- pyu2f: 0.1.5
- pyyaml: 6.0
- pyzmq: 25.1.0
- qtconsole: 5.5.1
- qtpy: 2.4.1
- regex: 2023.12.25
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rich: 13.7.1
- rjieba: 0.1.11
- rsa: 4.9
- safetensors: 0.4.2
- scipy: 1.10.1
- seaborn: 0.13.2
- send2trash: 1.8.2
- sentry-sdk: 1.11.0
- setproctitle: 1.3.2
- setuptools: 59.5.0
- shortuuid: 1.0.11
- sip: 6.7.12
- six: 1.16.0
- smmap: 3.0.5
- sniffio: 1.2.0
- soupsieve: 2.5
- stack-data: 0.2.0
- sympy: 1.12
- tabulate: 0.9.0
- tensorboard: 2.16.2
- tensorboard-data-server: 0.7.2
- tensorboard-plugin-wit: 1.8.1
- tensorflow-cpu: 2.16.2
- tensorflow-io-gcs-filesystem: 0.37.0
- tensorstore: 0.1.63
- termcolor: 2.4.0
- terminado: 0.17.1
- tinycss2: 1.2.1
- tmtools: 0.2.0
- tokenizers: 0.15.2
- toml: 0.10.2
- tomli: 2.0.1
- toolz: 0.12.0
- torch: 2.1.2
- torchaudio: 2.1.2
- torchmetrics: 0.10.3
- torchvision: 0.16.2
- tornado: 6.3.3
- tqdm: 4.64.1
- trainable-folding: 0.0.0
- traitlets: 5.7.1
- transformers: 4.39.3
- triton: 2.1.0
- tunedabs: 0.0.1
- typeshed-client: 2.5.1
- typing-extensions: 4.10.0
- unicodedata2: 15.0.0
- urllib3: 1.26.11
- wandb: 0.13.5
- wcwidth: 0.2.5
- webencodings: 0.5.1
- websocket-client: 0.58.0
- werkzeug: 2.2.2
- wheel: 0.40.0
- widgetsnbextension: 4.0.5
- wrapt: 1.16.0
- yarl: 1.8.1
- zipp: 3.10.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.13
- release: 3.10.0-693.17.1.el7.x86_64
- version: #1 SMP Thu Jan 25 20:13:58 UTC 2018
More info
No response