GPU memory consumption fluctuates rapidly with FSDP training #13594
manideep2510 asked this question in DDP / multi-GPU / multi-node
Hi, I'm trying to use FSDP training with PyTorch Lightning for a classification task whose classification layer has 1 million classes, with 512 units in the layer before it. Without FSDP, I could not fit the model onto 4 GPUs, as expected, since the last layer alone has around 512 million parameters. With FSDP sharding that last layer, I am able to fit the model and train it on the same 4 GPUs.
However, with FSDP the GPU memory consumption fluctuates rapidly between 6.5 GB and 10 GB, which prevents me from increasing the batch size. Is this behavior expected, is it a bug in fairscale/Lightning, or am I using FSDP incorrectly? Any help would be appreciated.
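To pin down where the swings come from, a small helper like the one below can be dropped into `training_step` (the helper name is just illustrative); it distinguishes memory actually held by tensors (activations, all-gathered FSDP shards, gradients) from memory merely reserved by PyTorch's caching allocator:

```python
import torch


def log_cuda_memory(tag: str) -> None:
    # memory_allocated(): bytes currently occupied by live tensors on this device
    # memory_reserved(): total bytes held by PyTorch's caching allocator
    alloc_gb = torch.cuda.memory_allocated() / 2**30
    reserved_gb = torch.cuda.memory_reserved() / 2**30
    print(f"[{tag}] allocated={alloc_gb:.2f} GB, reserved={reserved_gb:.2f} GB")


# e.g. call log_cuda_memory("before_forward") and log_cuda_memory("after_backward")
# inside training_step to see at which point the peaks appear
```

If only the reserved number swings, the fluctuation is the caching allocator rather than the model itself; if the allocated number swings, it is likely the transient all-gather of the sharded last layer during forward/backward.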
System & Environment
4 x RTX 2080Ti, CUDA 11.6
12 Core Intel Xeon
64 GB RAM
Torch 1.12.0
Pytorch-lightning 1.6.4
fairscale 0.4.6
Python 3.9.13
Memory Consumption
(Attached screen recording: Screen.Recording.2022-07-11.at.12.05.49.PM.mov)
The part of the LightningModule where the final Linear layer is wrapped with FSDP's wrap():
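In sketch form, it follows the standard `configure_sharded_model` pattern from the Lightning docs (the class name, placeholder backbone, loss, and optimizer below are illustrative, not my exact code):

```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
from fairscale.nn import wrap


class MillionClassClassifier(pl.LightningModule):
    def __init__(self, embed_dim: int = 512, num_classes: int = 1_000_000):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.backbone = nn.Sequential(nn.Linear(2048, embed_dim), nn.ReLU())  # placeholder backbone
        self.classifier = None  # built and sharded in configure_sharded_model

    def configure_sharded_model(self):
        # With the fairscale-backed strategy="fsdp", Lightning calls this hook inside
        # an enable_wrap() context, so wrap() shards the 512 x 1M linear layer here.
        self.classifier = wrap(nn.Linear(self.embed_dim, self.num_classes))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.classifier(self.backbone(x))
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
```

Building the large layer inside `configure_sharded_model` (rather than in `__init__`) lets it be wrapped and sharded as it is created, instead of first materializing the full 512M-parameter layer on one device.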