Commit 1524251

fix(svdquant): run SVD on GPU to maintain utilization (NVIDIA#633)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Fix SVD quantization running on CPU instead of GPU, which caused jobs to be killed by the internal job scheduler due to low GPU utilization during long-running SVD operations.

## Changes

- Keep the tensor on GPU during SVD computation by explicitly specifying the device
- Add `full_matrices=False` for faster computation (only the first `lowrank` singular vectors are needed)

## Motivation

The original implementation ran SVD on CPU, causing:

1. Jobs killed by the internal scheduler due to low GPU utilization during the SVD phase
2. Potential performance degradation from CPU-GPU data transfer overhead

## Usage

No API changes. Existing svdquant usage remains the same.

## Testing

- Tested locally with Wan2.2 SVD quantization
- Verified the job is no longer killed due to low GPU utilization

## Before your PR is "*Ready for review*"

- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

Signed-off-by: Taekyung Heo <[email protected]>
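The `full_matrices=False` change pays off because a full SVD of an m×n weight materializes square (m, m) and (n, n) factors even though only the first `lowrank` singular vectors are ever used. A minimal sketch of the shape difference, using NumPy as a stand-in for `torch.linalg.svd` (which exposes the same `full_matrices` flag); none of this is the actual modelopt code:

```python
import numpy as np

m, n = 128, 32
weight = np.random.default_rng(0).standard_normal((m, n))

# Full SVD: square factors, mostly wasted work for a low-rank use case.
u_full, s_full, vt_full = np.linalg.svd(weight, full_matrices=True)

# Reduced ("thin") SVD: only min(m, n) columns/rows are computed.
u_thin, s_thin, vt_thin = np.linalg.svd(weight, full_matrices=False)

print(u_full.shape, vt_full.shape)  # (128, 128) (32, 32)
print(u_thin.shape, vt_thin.shape)  # (128, 32) (32, 32)

# The singular values (and leading singular vectors) agree between the two forms.
assert np.allclose(s_full, s_thin)
```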
1 parent 11728b7 commit 1524251

File tree

1 file changed (+11 −4 lines)


modelopt/torch/quantization/model_calib.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -1029,7 +1029,11 @@ def svdquant(

     def postprocess(module, name):
         print_rank_0(f"SVD {name}")
-        u, s, vt = torch.linalg.svd(module.weight.data.double())
+        weight = module.weight.data
+        original_device = weight.device
+        original_dtype = weight.dtype
+        weight_f64 = weight.to(dtype=torch.float64, device=original_device)
+        u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(
                 "The low-rank dimensions do not match the layer dimensions. "
@@ -1039,9 +1043,12 @@ def postprocess(module, name):
             return
         us = u[:, :lowrank] * s[:lowrank]
         vt = vt[:lowrank]
-        dtype = module.weight.dtype
-        module.weight_quantizer.svdquant_lora_a = vt.to(dtype=dtype)
-        module.weight_quantizer.svdquant_lora_b = us.to(dtype=dtype)
+        module.weight_quantizer.svdquant_lora_a = vt.to(
+            dtype=original_dtype, device=original_device
+        )
+        module.weight_quantizer.svdquant_lora_b = us.to(
+            dtype=original_dtype, device=original_device
+        )
         module.weight.data.sub_(
             module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
         )
```
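Putting the pieces together, the patched flow (upcast to float64, reduced SVD, keep the first `lowrank` singular vectors, cast back, subtract the low-rank part in place) can be sketched in NumPy. This is an illustrative stand-in for the torch code: variable names mirror the patch, but NumPy cannot model the GPU-device preservation, and nothing here is the actual modelopt implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
weight = rng.standard_normal((64, 48)).astype(np.float32)
lowrank = 8
original_dtype = weight.dtype

# Compute the SVD in float64 for accuracy, mirroring weight.to(torch.float64).
u, s, vt = np.linalg.svd(weight.astype(np.float64), full_matrices=False)

# Keep only the first `lowrank` singular vectors and cast back to the
# original dtype, like .to(dtype=original_dtype) in the patch.
lora_b = (u[:, :lowrank] * s[:lowrank]).astype(original_dtype)  # (64, 8)
lora_a = vt[:lowrank].astype(original_dtype)                    # (8, 48)

# Subtract the low-rank component in place; this residual is what gets quantized.
weight -= lora_b @ lora_a
```

By the Eckart–Young theorem, `lora_b @ lora_a` is the best rank-8 approximation of the original weight, so the residual's Frobenius norm equals the root-sum-square of the discarded singular values.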

0 commit comments
