
CUDA OOM when running QAT for gpt-oss-20b #329

@eduardzl

Description

I am trying to run QAT on the gpt-oss-20b model after FFT, on an instance with 8x A100 GPUs (8x80GB).
I get a CUDA OOM error with both configurations (sft_full.yaml and sft_lora.yaml) whenever max_length is greater than 4096; in my case it is set to 32k.

The command is (flash_attn is installed):

accelerate launch --config_file configs/zero3.yaml sft.py  --config configs/sft_lora.yaml \
        --quant_cfg MXFP4_MLP_WEIGHT_ONLY_CFG --attn_implementation kernels-community/vllm-flash-attn3
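
To confirm where the 4096 threshold sits, a sweep like the one below is what I would run (a sketch only; --max_length and --max_steps are assumed to map onto the corresponding YAML/SFTConfig fields via the script's TrlParser, and editing max_length in the YAML should be equivalent):

# Hypothetical sweep over the sequence length, a few steps per run; the
# flag-to-field mapping via TrlParser is an assumption on my part.
for LEN in 4096 8192 16384 32768; do
  accelerate launch --config_file configs/zero3.yaml sft.py --config configs/sft_lora.yaml \
      --quant_cfg MXFP4_MLP_WEIGHT_ONLY_CFG --attn_implementation kernels-community/vllm-flash-attn3 \
      --max_length "$LEN" --max_steps 5 --output_dir "/tmp/qat-oom-len-$LEN"
done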

No matter what I tweak, the training OOMs after a couple of steps:

12%|███████████████                                                                                                         | 1/8 [01:38<11:28, 98.31s/it][Rank 0] Memory usage at training step 1: memory (MB) | allocated:  6.44e+01 | max_allocated:  8.34e+03 | reserved:  5.54e+04 | max_reserved:  5.54e+04
{'loss': 7.1055, 'grad_norm': 7.644806500757786, 'learning_rate': 0.0, 'entropy': 2.1449486911296844, 'num_tokens': 138061.0, 'mean_token_accuracy': 0.1823525046929717, 'epoch': 0.13}
{'loss': 7.1342, 'grad_norm': 8.164548440793329, 'learning_rate': 0.0002, 'entropy': 2.1739209294319153, 'num_tokens': 278805.0, 'mean_token_accuracy': 0.1771835247054696, 'epoch': 0.26}
{'loss': 6.7401, 'grad_norm': 6.377030560710944, 'learning_rate': 0.00019108719811121772, 'entropy': 2.2128825187683105, 'num_tokens': 416105.0, 'mean_token_accuracy': 0.19561117608100176, 'epoch': 0.38}
{'loss': 6.1746, 'grad_norm': 5.935622939274496, 'learning_rate': 0.00016611408216728603, 'entropy': 2.240649074316025, 'num_tokens': 554550.0, 'mean_token_accuracy': 0.21982254460453987, 'epoch': 0.51}
 50%|████████████████████████████████████████████████████████████                                                            | 4/8 [06:28<06:27, 96.90s/it][rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/TensorRT-Model-Optimizer/examples/gpt-oss/sft.py", line 123, in <module>
[rank1]:     main(script_args, training_args, model_args, quant_args)
[rank1]:   File "/workspace/TensorRT-Model-Optimizer/examples/gpt-oss/sft.py", line 112, in main
[rank1]:     trainer.train()
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/modelopt/torch/quantization/plugins/transformers_trainer.py", line 254, in train
[rank1]:     outputs = super().train(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/transformers/trainer.py", line 2328, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/transformers/trainer.py", line 2672, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/modelopt/torch/quantization/plugins/transformers_trainer.py", line 235, in training_step
[rank1]:     return super().training_step(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 1189, in training_step
[rank1]:     return super().training_step(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/transformers/trainer.py", line 4060, in training_step
[rank1]:     self.accelerator.backward(loss, **kwargs)
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/accelerate/accelerator.py", line 2726, in backward
[rank1]:     self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 270, in backward
[rank1]:     self.engine.backward(loss, **kwargs)
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2298, in backward
[rank1]:     self._do_optimizer_backward(loss, retain_graph)
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2244, in _do_optimizer_backward
[rank1]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/deepspeed/runtime/zero/stage3.py", line 2305, in backward
[rank1]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 65, in backward
[rank1]:     scaled_loss.backward(retain_graph=retain_graph)
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/torch/_tensor.py", line 647, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/torch/autograd/__init__.py", line 354, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/workspace/quant-env/lib/python3.12/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 24.55 GiB. GPU 1 has a total capacity of 79.25 GiB of which 18.90 GiB is free. Including non-PyTorch memory, this process has 60.33 GiB memory in use. Of the allocated memory 56.34 GiB is allocated by PyTorch, and 2.42 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W0917 10:24:13.135000 17232 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 17457 closing signal SIGTERM
W0917 10:24:13.138000 17232 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 17459 closing signal SIGTERM
W0917 10:24:13.140000 17232 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 17460 closing signal SIGTERM
W0917 10:24:13.156000 17232 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 17461 closing signal SIGTERM
W0917 10:24:13.159000 17232 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 17462 closing signal SIGTERM
W0917 10:24:13.165000 17232 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 17463 closing signal SIGTERM
W0917 10:24:13.167000 17232 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 17464 closing signal SIGTERM
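
The OOM message itself suggests an allocator setting; noting it for completeness, though it only mitigates fragmentation (reserved-but-unallocated memory) and the failing request here is a single 24.55 GiB allocation, so I do not expect it to be the actual fix:

# Allocator tweak quoted from the OOM message; set it before launching.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# ...then rerun the exact same accelerate launch command as above.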

For comparison, I was able to run FFT on the same instance at 32k context length with Axolotl, with no CUDA OOMs.
With QAT, I am not even able to run LoRA training.
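
In case it helps narrow things down, these are the kinds of memory-saving overrides I would expect to matter here (a sketch; the flag names are standard TrainingArguments/SFTConfig fields, assumed to be exposed as CLI flags by the script's TrlParser, and the example YAMLs may already set some of them):

# Hypothetical memory-saving overrides (not the repo's defaults): activation
# checkpointing plus the smallest per-device batch, compensated with accumulation.
accelerate launch --config_file configs/zero3.yaml sft.py --config configs/sft_lora.yaml \
    --quant_cfg MXFP4_MLP_WEIGHT_ONLY_CFG --attn_implementation kernels-community/vllm-flash-attn3 \
    --gradient_checkpointing True --per_device_train_batch_size 1 --gradient_accumulation_steps 8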

Env:
Latest code from "main" was used.
torch==2.8.0
CUDA 12.6
transformers==4.56.1

Any help would be greatly appreciated.
