Bug description
I'm using FSDP and model checkpointing (default settings for both). My model has 254 million parameters. I'm not sure why, but when I run Trainer.fit() the first epoch completes successfully and then I hit a CUDA OOM on the first backward pass of the second epoch. The problem goes away when I disable model checkpointing, which makes me think it is a bug in FSDP model checkpointing; after all, why should model checkpointing cause a CUDA OOM? The problem persists with both state_dict_type='sharded' and regular model checkpoints.
Also, the fact that the OOM happens AFTER the checkpoint is saved (during the next backward()) makes me think the model checkpointing could be causing some kind of memory leak.
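For reference, a minimal sketch along the lines of the configuration described above (the actual 254M-parameter model, data, and device count are not included in the report, so ToyModule, the toy data, and devices=8 are placeholders; the traceback only shows that rank 4 exists):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import FSDPStrategy


class ToyModule(pl.LightningModule):
    """Tiny stand-in for the real 254M-parameter model."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))),
    batch_size=8,
)

trainer = pl.Trainer(
    max_epochs=2,  # the OOM appears on the first backward of the second epoch
    accelerator="gpu",
    devices=8,  # placeholder; at least 5 ranks are visible in the traceback
    strategy=FSDPStrategy(),  # also reproduces with FSDPStrategy(state_dict_type="sharded")
    callbacks=[ModelCheckpoint()],  # default settings, as in the report
)
trainer.fit(ToyModule(), train_loader)
```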
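To check the memory-leak suspicion, a callback like the following (not part of the original report; CudaMemoryProbe is a made-up name) could log CUDA memory at each epoch boundary. If checkpointing retains memory, successive epoch-start readings would keep climbing instead of returning to the first epoch's level.

```python
import torch
import pytorch_lightning as pl


class CudaMemoryProbe(pl.Callback):
    """Prints allocated/reserved CUDA memory at epoch boundaries on each rank."""

    def _report(self, tag: str) -> None:
        allocated = torch.cuda.memory_allocated() / 2**30
        reserved = torch.cuda.memory_reserved() / 2**30
        print(f"[{tag}] allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

    def on_train_epoch_start(self, trainer, pl_module):
        self._report(f"epoch {trainer.current_epoch} start")

    def on_train_epoch_end(self, trainer, pl_module):
        self._report(f"epoch {trainer.current_epoch} end")
```

Adding it alongside the checkpoint callback, e.g. callbacks=[ModelCheckpoint(), CudaMemoryProbe()], would show whether memory held after the first epoch's checkpoint save is ever released.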
What version are you seeing the problem on?
v2.3
Error Message(s) (abbreviated):
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 239, in backward_fn
[rank4]: call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank4]: output = fn(*args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 212, in backward
[rank4]: self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
[rank4]: model.backward(tensor, *args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1103, in backward
[rank4]: loss.backward(*args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank4]: torch.autograd.backward(
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/autograd/init.py", line 267, in backward
[rank4]: _engine_run_backward(
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank4]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank4]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.57 GiB. GPU
cc @lantiga