Bug description
Using the DeepSpeed strategy with the following configuration:
_target_: lightning.pytorch.strategies.DeepSpeedStrategy
zero_optimization: true
stage: 3
allgather_bucket_size: 2e8
reduce_bucket_size: 2e8
offload_optimizer: false
offload_parameters: false
partition_activations: false
cpu_checkpointing: false
contiguous_gradients: false
overlap_comm: false
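Hydra's `_target_` key tells `hydra.utils.instantiate` to call the named class with the remaining keys as keyword arguments. The config above should therefore correspond roughly to the direct construction below. This is a sketch only. `STRATEGY_KWARGS` and `build_strategy` are hypothetical names, and the bucket sizes are written as integers instead of the YAML float `2e8`.

```python
# Hypothetical direct-Python equivalent of the Hydra config above (sketch).
# Keys mirror the YAML keys one-to-one; bucket sizes are written as ints.
STRATEGY_KWARGS = {
    "zero_optimization": True,
    "stage": 3,
    "allgather_bucket_size": int(2e8),
    "reduce_bucket_size": int(2e8),
    "offload_optimizer": False,
    "offload_parameters": False,
    "partition_activations": False,
    "cpu_checkpointing": False,
    "contiguous_gradients": False,
    "overlap_comm": False,
}


def build_strategy():
    """Instantiate the strategy; needs lightning and deepspeed installed."""
    # Imported lazily so this module can be loaded without deepspeed present.
    from lightning.pytorch.strategies import DeepSpeedStrategy

    return DeepSpeedStrategy(**STRATEGY_KWARGS)
```

The result would then be passed to the trainer as `Trainer(strategy=build_strategy(), ...)`.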
I am experiencing an issue (specifically with DeepSpeed stage 3, not stages 1-2) where tensors registered with `register_buffer()` in the sub-`nn.Module`s of my `LightningModule`'s main `lit_model.network` `nn.Module` are not moved to the correct device when training `lit_module.network`. In particular, I am trying to register buffers as
from torch import tensor
# in each submodule (self is the submodule):
distance_bins_tensor = tensor([0.0, 1.0, 2.0, 3.0])
self.register_buffer("distance_bins", distance_bins_tensor)
within the various submodules of my `lit_module.network`. When my optimizer tries to perform a step, I get the error
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu!
when trying to use these registered buffers, e.g., by multiplying them by feature tensors loaded onto (in this case) `cuda:6`.
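In plain PyTorch, buffers registered in nested submodules are reported by the parent's `named_buffers()` and moved together with the parameters by `Module.to()`, so they would normally follow the model's device. The sketch below uses hypothetical names throughout (`buffers_not_on`, `DistanceBins`, `Network`). It adds a small diagnostic that could be called from a hook such as `on_train_start` with `self.device` to check which buffers are left on the CPU under ZeRO-3. It also adds a defensive workaround that moves the buffer to the features' device inside `forward`. It does not address whatever causes the strategy to skip the buffers.

```python
import torch
from torch import nn


def _matches(actual: torch.device, expected: torch.device) -> bool:
    # An index-less expected device such as "cuda" matches any index of that type.
    return actual.type == expected.type and (
        expected.index is None or actual.index == expected.index
    )


def buffers_not_on(module: nn.Module, device) -> list[str]:
    """Return qualified names of buffers (submodules included) not on `device`."""
    expected = torch.device(device)
    return [
        name
        for name, buf in module.named_buffers()
        if not _matches(buf.device, expected)
    ]


class DistanceBins(nn.Module):
    """Hypothetical submodule that owns a buffer, mirroring the snippet above."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("distance_bins", torch.tensor([0.0, 1.0, 2.0, 3.0]))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Workaround sketch: match the buffer to the features' device and dtype
        # at use time. Tensor.to() returns the tensor itself if nothing changes.
        bins = self.distance_bins.to(device=features.device, dtype=features.dtype)
        return features * bins


class Network(nn.Module):
    """Hypothetical stand-in for lit_module.network."""

    def __init__(self) -> None:
        super().__init__()
        self.bins = DistanceBins()

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.bins(features)
```

For example, `buffers_not_on(Network(), "cpu")` is empty, while `buffers_not_on(Network(), "cuda:6")` lists `bins.distance_bins`. If a buffer really is stuck on the CPU, the workaround copies it host-to-device on every forward call, so it is a stopgap rather than a fix.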
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA A100 80GB PCIe
- NVIDIA A100 80GB PCIe
- available: True
- version: 11.8
- Lightning:
- adam-atan2-pytorch: 0.0.10
- alphafold3-pytorch: 0.0.41
- alphafold3-pytorch-lightning-hydra: 0.1.111
- frame-averaging-pytorch: 0.0.19
- lightning: 2.4.0
- lightning-utilities: 0.11.6
- pytorch-lightning: 2.4.0
- rotary-embedding-torch: 0.6.1
- torch: 2.3.0+cu118
- torch-geometric: 2.5.3
- torchaudio: 2.3.0+cu118
- torchmetrics: 1.4.1
- torchtyping: 0.1.4
- torchvision: 0.18.0+cu118
- Packages:
- adam-atan2-pytorch: 0.0.10
- aiofiles: 23.2.1
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- alembic: 1.13.1
- alphafold3-pytorch: 0.0.41
- alphafold3-pytorch-lightning-hydra: 0.1.111
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.4.0
- appdirs: 1.4.4
- argcomplete: 3.3.0
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 23.2.0
- autopage: 0.5.2
- beartype: 0.18.5
- beautifulsoup4: 4.12.3
- biopandas: 0.5.1.dev0
- biopython: 1.83
- bioservices: 1.11.2
- cattrs: 23.2.3
- certifi: 2024.8.30
- cfgv: 3.4.0
- chardet: 5.2.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- cliff: 4.7.0
- cmaes: 0.10.0
- cmd2: 2.4.3
- colorama: 0.4.6
- colorlog: 6.8.2
- colt5-attention: 0.11.0
- comm: 0.2.2
- contourpy: 1.2.1
- cycler: 0.12.1
- debugpy: 1.8.1
- decorator: 5.1.1
- deepdiff: 7.0.1
- deepspeed: 0.15.0
- distlib: 0.3.8
- docker-pycreds: 0.4.0
- easydev: 0.13.2
- einops: 0.8.0
- einx: 0.2.2
- environs: 11.0.0
- exceptiongroup: 1.2.1
- executing: 2.0.1
- fastapi: 0.112.2
- ffmpy: 0.4.0
- filelock: 3.13.1
- fonttools: 4.52.4
- frame-averaging-pytorch: 0.0.19
- freetype-py: 2.3.0
- frozendict: 2.4.4
- frozenlist: 1.4.1
- fsspec: 2024.2.0
- gemmi: 0.6.6
- gevent: 24.2.1
- gitdb: 4.0.11
- gitpython: 3.1.43
- gradio: 4.43.0
- gradio-client: 1.3.0
- gradio-molecule3d: 0.0.5
- graphein: 1.7.6
- greenlet: 3.0.3
- grequests: 0.7.0
- h11: 0.14.0
- hjson: 3.1.0
- httpcore: 1.0.5
- httpx: 0.27.2
- huggingface-hub: 0.23.4
- hydra-colorlog: 1.2.0
- hydra-core: 1.3.2
- hydra-optuna-sweeper: 1.2.0
- identify: 2.5.36
- idna: 3.7
- importlib-resources: 6.4.4
- iniconfig: 2.0.0
- ipykernel: 6.29.4
- ipython: 8.24.0
- jaxtyping: 0.2.28
- jedi: 0.19.1
- jinja2: 3.1.3
- joblib: 1.4.2
- jupyter-client: 8.6.2
- jupyter-core: 5.7.2
- kiwisolver: 1.4.5
- lightning: 2.4.0
- lightning-utilities: 0.11.6
- line-profiler: 4.1.3
- local-attention: 1.9.1
- loguru: 0.7.2
- looseversion: 1.1.2
- lxml: 5.2.2
- mako: 1.3.5
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- marshmallow: 3.21.3
- matplotlib: 3.8.4
- matplotlib-inline: 0.1.7
- mdurl: 0.1.2
- mmtf-python: 1.1.3
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- multipledispatch: 1.0.0
- munkres: 1.1.4
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- ninja: 1.11.1.1
- nodeenv: 1.8.0
- numpy: 1.23.5
- nvidia-cublas-cu11: 11.11.3.6
- nvidia-cuda-cupti-cu11: 11.8.87
- nvidia-cuda-nvrtc-cu11: 11.8.89
- nvidia-cuda-runtime-cu11: 11.8.89
- nvidia-cudnn-cu11: 8.7.0.84
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.3.0.86
- nvidia-cusolver-cu11: 11.4.1.48
- nvidia-cusparse-cu11: 11.7.5.86
- nvidia-ml-py: 12.560.30
- nvidia-nccl-cu11: 2.20.5
- nvidia-nvtx-cu11: 11.8.86
- omegaconf: 2.3.0
- optree: 0.11.0
- optuna: 2.10.1
- ordered-set: 4.1.0
- orjson: 3.10.7
- packaging: 24.0
- pandas: 1.5.3
- parso: 0.8.4
- pbr: 6.0.0
- pdbeccdutils: 0.8.5
- pexpect: 4.9.0
- pillow: 10.2.0
- pip: 24.0
- pipx: 1.5.0
- platformdirs: 4.2.2
- plotly: 5.22.0
- pluggy: 1.5.0
- polars: 1.3.0
- pre-commit: 3.7.1
- prettytable: 3.10.0
- prompt-toolkit: 3.0.45
- protobuf: 4.25.4
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 9.0.0
- pycairo: 1.26.0
- pydantic: 2.8.2
- pydantic-core: 2.20.1
- pydub: 0.25.1
- pygments: 2.18.0
- pyparsing: 3.1.2
- pyperclip: 1.8.2
- pytest: 8.2.1
- python-dateutil: 2.9.0
- python-dotenv: 1.0.1
- python-multipart: 0.0.9
- pytorch-lightning: 2.4.0
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- rdkit: 2024.3.2
- reportlab: 4.1.0
- requests: 2.32.2
- requests-cache: 1.2.0
- retrying: 1.3.4
- rich: 13.7.1
- rich-click: 1.8.2
- rlpycairo: 0.2.0
- rootutils: 1.0.7
- rotary-embedding-torch: 0.6.1
- ruff: 0.6.4
- scikit-learn: 1.5.0
- scipy: 1.13.1
- seaborn: 0.13.2
- semantic-version: 2.10.0
- sentry-sdk: 2.12.0
- setproctitle: 1.3.3
- setuptools: 70.0.0
- sh: 2.0.7
- shellingham: 1.5.4
- shortuuid: 1.0.13
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- soupsieve: 2.5
- sqlalchemy: 2.0.30
- stack-data: 0.6.3
- starlette: 0.38.4
- stevedore: 5.2.0
- suds-community: 1.1.2
- sympy: 1.12
- taylor-series-linear-attention: 0.1.12
- tenacity: 8.3.0
- threadpoolctl: 3.5.0
- timeout-decorator: 0.5.0
- tomli: 2.0.1
- tomlkit: 0.12.0
- torch: 2.3.0+cu118
- torch-geometric: 2.5.3
- torchaudio: 2.3.0+cu118
- torchmetrics: 1.4.1
- torchtyping: 0.1.4
- torchvision: 0.18.0+cu118
- tornado: 6.4
- tqdm: 4.66.4
- traitlets: 5.14.3
- triton: 2.3.0
- typeguard: 2.13.3
- typer: 0.12.5
- typing-extensions: 4.11.0
- tzdata: 2024.1
- unicodedata2: 15.1.0
- url-normalize: 1.4.3
- urllib3: 2.2.1
- userpath: 1.9.2
- uvicorn: 0.30.6
- virtualenv: 20.26.2
- wandb: 0.16.6
- wcwidth: 0.2.13
- websockets: 12.0
- wget: 3.2
- wheel: 0.43.0
- wrapt: 1.16.0
- xarray: 2024.3.0
- xmltodict: 0.13.0
- yarl: 1.9.4
- zope.event: 5.0
- zope.interface: 6.4.post2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.14
- release: 4.18.0-553.16.1.el8_10.x86_64
- version: #1 SMP Thu Aug 8 07:11:46 EDT 2024
More info
No response
cc @lantiga