Commit 9a85e49

[NVBUG: 5612606] Clear GPU cache for large models layer quantization during export (#497)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** For large models like Llama 4 Maverick, converting the stacked weights to FP8 during export can hit GPU OOM. This change aims to fix that.

---------

Signed-off-by: Chenjie Luo <[email protected]>
1 parent 5f0ef3b commit 9a85e49

File tree

1 file changed (+3 −0 lines)


modelopt/torch/export/quant_utils.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -37,6 +37,7 @@
     quantizer_attr_names,
     weight_attr_names,
 )
+from modelopt.torch.utils import clear_cuda_cache
 
 from ..quantization.nn import SequentialQuantizer, TensorQuantizer
 from .model_config import (
@@ -763,6 +764,8 @@ def to_quantized_weight(
 
         if weight.dim() == 3:
            # for MOE stacked weights
+           # Clear GPU cache to avoid potential GPU OOM issues for large models.
+           clear_cuda_cache()
            return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
         return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)
```
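
For context on the fix: the element-wise division `weight / weights_scaling_factor` materializes a high-precision temporary the size of the full stacked MoE weight before the cast to `torch.float8_e4m3fn`, which is what can push a very large model over the GPU memory limit. The sketch below shows one way the cache-clearing pattern fits around that division. It is a minimal illustration under assumptions: `clear_cuda_cache` is the real helper imported in the diff, but its internals here are guessed, and `stacked_weight_to_fp8` is a hypothetical wrapper rather than a ModelOpt API.

```python
# Minimal sketch, NOT the ModelOpt source. The internals of clear_cuda_cache are
# assumed; stacked_weight_to_fp8 is a hypothetical wrapper for illustration only.
import gc

import torch


def clear_cuda_cache() -> None:
    """Assumed behavior: collect dead Python objects, then ask the CUDA caching
    allocator to release its unused blocks, freeing headroom for the large
    temporary created by the division below."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def stacked_weight_to_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize stacked (num_experts, out_dim, in_dim) MoE weights to FP8."""
    # Free cached GPU memory right before the memory-hungry element-wise division.
    clear_cuda_cache()
    return (weight / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
```

Placing the call immediately before the division is the key design point: the division allocates a new tensor in the weight's original dtype, so evicting cached-but-unused blocks at that moment maximizes the memory available for it.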
