
Model Checkpointing + FSDP causes CUDA OOM #20312

@profPlum

Bug description

I'm using FSDP together with model checkpointing (default settings for both) on a model with 254 million parameters. When I run Trainer.fit(), the first epoch completes successfully, but the run then hits a CUDA OOM on the first backward pass of the second epoch. The problem goes away when I disable model checkpointing, which makes me think it is a bug in FSDP model checkpointing; after all, why should saving a checkpoint cause a CUDA OOM? The problem persists with both state_dict_type='sharded' and regular full checkpoints.
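
For reference, here is a minimal sketch of the failing configuration. The BoringModel, layer sizes, random data, and devices=8 below are placeholders for illustration only, not the actual 254M-parameter model or cluster setup from this report:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import FSDPStrategy
from torch.utils.data import DataLoader, TensorDataset

class BoringModel(pl.LightningModule):
    """Tiny stand-in for the real 254M-parameter model."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8
)

trainer = pl.Trainer(
    strategy=FSDPStrategy(state_dict_type="sharded"),  # same OOM seen with the default "full"
    accelerator="gpu",
    devices=8,                      # placeholder; the trace shows rank 4, so at least 5 GPUs were in use
    callbacks=[ModelCheckpoint()],  # default settings
    max_epochs=2,
)
trainer.fit(BoringModel(), train_loader)  # epoch 1 completes, OOM on the first backward() of epoch 2
```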

The fact that the OOM happens AFTER the checkpoint is saved (during the next backward()) also makes me suspect that model checkpointing is leaking GPU memory somewhere.
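
To check that hypothesis, a callback along these lines could report CUDA memory at each epoch boundary; this is just a diagnostic sketch I would add next to the checkpoint callback, not part of the failing run:

```python
import torch
import pytorch_lightning as pl

class CudaMemoryLogger(pl.Callback):
    """Hypothetical diagnostic: report CUDA memory at each epoch end, to see
    whether allocated memory keeps growing after the checkpoint is saved."""

    def on_train_epoch_end(self, trainer, pl_module):
        allocated = torch.cuda.memory_allocated() / 2**30  # GiB on this rank's device
        reserved = torch.cuda.memory_reserved() / 2**30
        pl_module.print(  # LightningModule.print emits from rank 0 only
            f"[epoch {trainer.current_epoch}] "
            f"allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB"
        )
```

Passing callbacks=[ModelCheckpoint(), CudaMemoryLogger()] to the Trainer above would show whether the jump in allocated memory appears right after the epoch-1 save.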

What version are you seeing the problem on?

v2.3

Error Message(s) (abbreviated):

[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 239, in backward_fn
[rank4]: call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank4]: output = fn(*args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 212, in backward
[rank4]: self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
[rank4]: model.backward(tensor, *args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1103, in backward
[rank4]: loss.backward(*args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank4]: torch.autograd.backward(
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank4]: _engine_run_backward(
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank4]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank4]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.57 GiB. GPU

cc @lantiga
