From 2031d8e0fba6f01d33b34a32c1636eea1610938a Mon Sep 17 00:00:00 2001
From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
Date: Thu, 4 Dec 2025 14:52:05 -0800
Subject: [PATCH] [NVBUG: 5701937] Clear GPU cache for 3D weight tensors

Clear the GPU cache when quantizing 3D (MoE stacked) weight tensors to
prevent GPU OOM issues when exporting large models.

Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
---
 modelopt/torch/export/quant_utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index eee13dc51..a2efcd95d 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -756,6 +756,11 @@ def to_quantized_weight(
     if isinstance(weight, QTensorWrapper):
         return weight.data
 
+    if weight.dim() == 3:
+        # For MoE stacked weights.
+        # Clear GPU cache to avoid potential GPU OOM issues for large models.
+        clear_cuda_cache()
+
     if quantization == QUANTIZATION_FP8:
         # Fix RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Float
         # in speculative decoding fp8 model export
@@ -764,9 +769,6 @@
             return weight
 
         if weight.dim() == 3:
-            # for MOE stacked weights
-            # Clear GPU cache to avoid pontential GPU OOM issues for large models.
-            clear_cuda_cache()
             return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
 
         return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)
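
Note (illustration only, not part of the patch): a minimal sketch of the code path the
hunks above change, assuming `clear_cuda_cache` roughly wraps `torch.cuda.empty_cache()`.
The wrapper `to_fp8_stacked` and its parameters are hypothetical names that only mirror
the FP8 branch of `to_quantized_weight`; the real helper in modelopt may differ.

import torch


def clear_cuda_cache() -> None:
    # Assumption: the modelopt helper does roughly this -- release cached,
    # unused allocator blocks so the large temporaries created below
    # (the division and the fp8 cast) have room to fit.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def to_fp8_stacked(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in mirroring the FP8 branch of to_quantized_weight.
    # weight: [num_experts, out_features, in_features] MoE stacked weights.
    # scale:  per-output-channel scaling factors, [num_experts, out_features].
    if weight.dim() == 3:
        # Stacked expert weights are large; free cached memory before
        # materializing the scaled copy and its fp8 cast.
        clear_cuda_cache()
        return (weight / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return (weight / scale).to(torch.float8_e4m3fn)

The idea behind hoisting the call, as stated in the commit message, is that each step of
the conversion materializes a temporary on the order of the stacked expert weights, so
emptying the allocator cache first lowers peak GPU memory regardless of quantization format.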