Accumulating Batches (not Gradients) via custom Loops and avoid CUDA OOM #15116
Unanswered · myscience asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule · Replies: 0 comments
Hi everyone,
I am facing a CUDA out-of-memory (OOM) error when training my model on a custom dataset via DDP on 4 GPUs, and I would like to hear whether I am missing a simple solution here or where my mistake is.

The problem is that a single example in my dataset is quite big (a tensor of shape `(batch, 80, 105, 85)`), and the model itself has two sub-modules, one of which is a BEiT transformer (taken from HuggingFace) with ~85.7M parameters. The Lightning model summary estimates the full model at 185 MB. On a single GPU (which has 16 GB of memory) I can fit a batch size of 2, which is too small: I am using a contrastive-learning approach where each batch needs positive and negative examples, so a batch size of around 128 would be more reasonable.

My idea for solving this was the following. Before the loss computation, the model produces vector representations of the data that are far smaller (tensors of shape `(batch, 700)`), so I could accumulate several batches, collecting the vector representations until I reach something like `(128, 700)`, and only then compute the loss and update everything. The question is: does this make sense, and if so, how can I achieve this sort of behavior?
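To give a sense of the sizes involved, a back-of-the-envelope comparison for a batch of 128 (float32, illustrative only):

```python
# Raw inputs vs. latent representations for a batch of 128 (4 bytes per float32).
raw_inputs_mib = 128 * 80 * 105 * 85 * 4 / 2**20   # ~348.6 MiB of raw input tensors
latents_mib    = 128 * 700 * 4 / 2**20             # ~0.34 MiB of (128, 700) latents
```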
As I understood it, the Lightning API easily offers gradient accumulation, but I fear it is not useful here: in gradient accumulation the loss is computed on each individual mini-batch separately and only the gradients are accumulated, which in my case would result in very poor individual gradients (a contrastive loss over a mini-batch of 2 sees almost no negatives).
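For reference, the built-in mechanism I am referring to is just the `accumulate_grad_batches` Trainer flag:

```python
import pytorch_lightning as pl

# Built-in gradient accumulation: each tiny mini-batch still gets its own
# forward pass and loss; only the resulting gradients are summed over 64
# steps before the optimizer is stepped, so the contrastive loss itself
# never sees more than 2 examples at once.
trainer = pl.Trainer(accumulate_grad_batches=64)
```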
After some investigation I found out about the Lightning Loop API and thought I could use it to fit my needs. The idea was to subclass `TrainingEpochLoop` and request multiple batches from the `data_fetcher` using a generator (so that only one or two raw examples are in memory at a time), and then use the LightningModule hook `on_train_batch_start` to pre-process each batch, turning the `(1, 80, 105, 85)` tensor into the much more manageable `(1, 700)` tensor, and start accumulating those. What my code is doing at the moment looks something like the following.
In my `LightningModule` I have implemented the `on_train_batch_start` hook as follows (note that the `example2latent` function calls one sub-module of my model):
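(A simplified sketch: the class and sub-module names, the stubbed `contrastive_loss`, and the `training_step` that skips the optimizer step until the buffer is full are filled in here for context and are not my code verbatim.)

```python
import torch
import pytorch_lightning as pl


class ContrastiveModel(pl.LightningModule):
    """Relevant pieces only: `example2latent` calls one sub-module (the
    encoder); the second sub-module and the contrastive loss are stubbed."""

    def __init__(self, encoder: torch.nn.Module, head: torch.nn.Module, accumulate: int = 16):
        super().__init__()
        self.encoder = encoder          # e.g. the BEiT backbone
        self.head = head                # second sub-module used in the loss
        self.accumulate = accumulate
        self.latent_buffer = []         # collects (1, 700) latents

    def example2latent(self, batch: torch.Tensor) -> torch.Tensor:
        # One sub-module reduces the big (1, 80, 105, 85) example to (1, 700).
        return self.encoder(batch)

    def on_train_batch_start(self, batch, batch_idx):
        # Pre-process the incoming batch into its small latent and buffer it,
        # so that only the latents are kept around between batches.
        self.latent_buffer.append(self.example2latent(batch))

    def training_step(self, batch, batch_idx):
        if len(self.latent_buffer) < self.accumulate:
            # Not enough latents yet: returning None makes Lightning skip the
            # optimization step for this batch.
            return None
        latents = torch.cat(self.latent_buffer, dim=0)   # (accumulate, 700)
        self.latent_buffer = []
        return self.contrastive_loss(self.head(latents))

    def contrastive_loss(self, z: torch.Tensor) -> torch.Tensor:
        # Placeholder for the actual contrastive objective over positives/negatives.
        raise NotImplementedError
```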
Finally, in the main script I simply connect the custom loop as:
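(The `fit_loop.replace(...)` call follows the Loop documentation of recent 1.x releases; on other versions connecting an instance via `connect` may be needed instead. The Trainer flags mirror my 4-GPU DDP setup.)

```python
import pytorch_lightning as pl

# `model` and `datamodule` are my LightningModule / LightningDataModule.
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp")

# Swap the default training epoch loop for the accumulating one.
trainer.fit_loop.replace(epoch_loop=AccumulatingEpochLoop)

trainer.fit(model, datamodule=datamodule)
```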
The problem with all of this is that if I use, for example, `accumulate = 16` (thus aiming for a final latent tensor of shape `(16, 700)`), I still get the out-of-memory error I mentioned at the beginning. How can that be? Is this whole logic wrong? Do you have a more general suggestion on how to tackle this problem? Thanks!

P.S. I also tried turning my `torch.utils.data.Dataset` into a `torch.utils.data.IterableDataset` and playing with `prefetch_factor`, `num_workers`, and so on, without luck.
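(For completeness, the kind of DataLoader tuning I tried looked like this; the numbers are just examples.)

```python
from torch.utils.data import DataLoader

# `dataset` is my map-style Dataset (placeholder here). More workers and a
# larger prefetch only change how many raw batches are prepared ahead of
# time on the CPU side, which is why it did not help with the GPU OOM.
loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=4,        # prefetch_factor requires num_workers > 0
    prefetch_factor=2,
)
```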