
Commit 9cd0824

Make all tensors on same device for svdquant with cpu-offloading (NVIDIA#550)
## What does this PR do?

**Type of change:** Bug Fix

**Overview:** While running SVDQuant with CPU offloading enabled via the diffusers PTQ example (sd3.5-medium model), "not all tensors on same device" errors were observed at the following steps:

1. AWQ scale computation - `get_scale()` using `x_max` and `w_max`
2. Loss update for each alpha - `update_loss()`
3. `_apply_weight_pre_quant_scale()` - while multiplying by the pre-quant scale
4. `apply_pre_quant_scale_and_smooth()` - while multiplying by the pre-quant scale

The same errors should also occur with the Flux model when SVDQuant and CPU offloading are enabled. This change updates the places above to ensure the tensors involved are on the same device, using `.to(device)`. A minimal sketch of the failure mode is shown below.

## Testing

- Ran SVDQuant with CPU offloading enabled on sd3.5-medium, on an RTX 5090 under Windows 11 22621. With this change, the final ONNX model (transformer) was produced without any error.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---------

Signed-off-by: vipandya <[email protected]>
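As context for the fix, here is a minimal standalone sketch of the failure mode and the `.to(device)` pattern this commit applies. It is not part of the commit; the tensor names, sizes, and the 0.5 exponent are illustrative, and it assumes a CUDA-capable machine:

```python
import torch

# With CPU offloading, weight statistics can be left on the CPU while
# activation statistics are computed on the GPU.
x_max = torch.rand(64, device="cuda")  # activation amax, on the GPU
w_max = torch.rand(64)                 # weight amax, still on the CPU

try:
    # Mixed-device elementwise op: raises
    # "Expected all tensors to be on the same device ..."
    scales = x_max.pow(0.5) / w_max.pow(0.5)
except RuntimeError as e:
    print(e)

# The fix pattern used throughout this commit: align devices first.
scales = x_max.pow(0.5) / w_max.to(x_max.device).pow(0.5)
```

The same pattern covers the pre-quant-scale multiplications and the loss accumulation in the diff below.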
1 parent e3e399a commit 9cd0824

1 file changed: +11 -4 lines changed


modelopt/torch/quantization/model_calib.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -251,7 +251,9 @@ def disable_pre_quant_scale_and_resmooth(linear: nn.Module, delete_pre_quant_sca
 def _apply_weight_pre_quant_scale(linear, pre_quant_scale):
     if _ENABLE_FOLDING_PQS_TO_WEIGHTS:
         linear.weight.data.copy_(
-            (linear.weight * pre_quant_scale.squeeze()[None, :]).to(linear.weight.dtype)
+            (linear.weight * pre_quant_scale.to(linear.weight.device).squeeze()[None, :]).to(
+                linear.weight.dtype
+            )
         )
     else:
         linear.weight_quantizer._enable_pre_quant_scale = True
@@ -300,7 +302,9 @@ def apply_pre_quant_scale_and_smooth(
         _amax_for_smoothing = linear.input_quantizer._amax_for_smoothing.to(
             device=device, dtype=dtype
         )
-        linear.input_quantizer.amax = (_amax_for_smoothing * pre_quant_scale).amax().to(dtype)
+        linear.input_quantizer.amax = (
+            (_amax_for_smoothing * pre_quant_scale.to(device)).amax().to(dtype)
+        )
 
     if is_quantized_column_parallel_linear(linear) or is_quantized_row_parallel_linear(linear):
         linear.input_quantizer.sync_amax_across_distributed_group(
@@ -507,7 +511,10 @@ def get_act_scale(x):
 
     def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
         scales = (
-            (x_max.pow(alpha) / (w_max.pow(1 - alpha) + torch.finfo(torch.float32).tiny))
+            (
+                x_max.pow(alpha)
+                / (w_max.to(x_max.device).pow(1 - alpha) + torch.finfo(torch.float32).tiny)
+            )
             .clamp(min=1e-4, max=1e4)
             .view(-1)
         )
@@ -521,7 +528,7 @@ def update_loss(self, out, out_actual, alpha):
         out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
         out = out[0] if isinstance(out, tuple) else out
         loss = (out - out_actual).float().pow(2).mean()
-        self.awq_lite.loss[alpha] += loss
+        self.awq_lite.loss[alpha] += loss.to(self.awq_lite.loss[alpha].device)
 
     def update_best_params(self):
         if not self.awq_lite.is_enabled:
```
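For reference, here is a simplified standalone paraphrase of the patched `get_scale` helper. This is an illustrative sketch only: the real function also takes a `tensor_parallel_group` argument for syncing scales across ranks, and the name `awq_scale` is hypothetical. It shows where the device alignment lands in the AWQ-lite scale formula s = x_max^alpha / w_max^(1 - alpha):

```python
import torch

def awq_scale(x_max: torch.Tensor, w_max: torch.Tensor, alpha: float) -> torch.Tensor:
    # Per-channel AWQ scales: s = x_max^alpha / w_max^(1 - alpha).
    # `tiny` guards against division by zero; the clamp keeps scales in a
    # numerically sane range. `w_max.to(x_max.device)` is the CPU-offloading
    # fix: weight statistics may still sit on the CPU while activation
    # statistics were computed on the GPU.
    return (
        x_max.pow(alpha)
        / (w_max.to(x_max.device).pow(1 - alpha) + torch.finfo(torch.float32).tiny)
    ).clamp(min=1e-4, max=1e4).view(-1)
```

During calibration, AWQ-lite sweeps `alpha` over a grid of candidates, accumulates the MSE between quantized and original outputs per alpha (`update_loss`, where the fix moves the scalar loss onto the accumulator's device), and keeps the alpha with the lowest loss.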
