
Commit b8e5ad2

TensorQuantizer: remove validation from the forward path to reduce CPU-GPU sync; add SVDQuantTensorQuantizer; minor code cleanup
Signed-off-by: realAsma <[email protected]>
1 parent: d649fb8 · commit: b8e5ad2

File tree: 4 files changed (+161, -152 lines)


modelopt/torch/quantization/model_calib.py

Lines changed: 62 additions & 81 deletions
@@ -25,6 +25,7 @@
 import torch.nn.functional as F

 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                 continue

             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
             postprocess(module)

             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


 def awq(
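Several progress messages in this file switch from print to print_rank_0 (imported in the first hunk). The sketch below shows the behavior assumed of that helper, printing once per job instead of once per process; it is illustrative only, not the actual modelopt implementation:

import torch.distributed as dist

def print_rank_0(*args, **kwargs):
    # Assumed behavior: print only on global rank 0 (or when not running under
    # torch.distributed), so calibration progress is not repeated on every GPU.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)
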
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return

     class AWQLiteHelper:
         cache_mode: bool = False
@@ -493,11 +451,32 @@ def __init__(self, module, name):
             self.num_search_steps = 0
             self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
             self.weight_scale = get_weight_scale(module.weight, self.block_size)
-            self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
+            self.loss = {
+                k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
             self.best_scale = None
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")

     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
@@ -534,10 +513,13 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
     def update_loss(self, out, out_actual, alpha):
         out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
         out = out[0] if isinstance(out, tuple) else out
-        loss = (out - out_actual).float().pow(2).mean().item()
+        loss = (out - out_actual).float().pow(2).mean()
         self.awq_lite.loss[alpha] += loss

     def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
+        self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()})
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
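This hunk is the main CPU-GPU sync reduction named in the commit message: the per-alpha loss is now accumulated as a 0-dim tensor on the weight's device and is converted to Python floats only once, in update_best_params, instead of calling .item() on every search forward pass. A minimal sketch of the pattern (illustrative, not the modelopt code):

import torch

def mse_accumulated_on_device(pairs):
    """Accumulate an MSE-style loss over many batches with a single host sync.

    `pairs` yields (quantized_out, reference_out) tensors on the GPU. Calling
    .item() inside the loop would block the host on every batch; keeping the
    running total as a 0-dim tensor defers that to one float() call at the end.
    """
    total = None
    for out, ref in pairs:
        step = (out - ref).float().pow(2).mean()  # stays on the device
        total = step if total is None else total + step
    return float(total)  # the only synchronization point
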
@@ -560,7 +542,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()

-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
             return out_actual

         if AWQLiteHelper.cache_mode:
@@ -589,7 +572,6 @@ def forward(self, input, *args, **kwargs):
             self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
             self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
             out = self._forward_no_awq(input, *args, **kwargs)
-
             update_loss(self, out, out_actual, alpha)

         self.awq_lite.num_search_steps += 1
@@ -601,19 +583,11 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1
+            module.awq_lite.setup()

     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")

     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -631,22 +605,25 @@ def forward(self, input, *args, **kwargs):
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+                torch.isnan(module.awq_lite.weight_scale)
+            ):
+                module.awq_lite.is_enabled = False
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True

     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +632,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)

+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


 @torch.no_grad()
@@ -858,7 +839,7 @@ def forward(name, self, input, *args, **kwargs):
         with enable_weight_access_and_writeback(module, model):
             module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +900,7 @@ def svdquant(
     """

     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(

modelopt/torch/quantization/model_quant.py

Lines changed: 5 additions & 0 deletions
@@ -105,6 +105,11 @@ def forward_loop(model):
         mode_kwargs={"forward_loop": forward_loop},
     )

+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer):
+            for attr_name in ["_amax", "_pre_quant_scale"]:
+                module.validate_attr(attr_name=attr_name, warn_error=True, name=name)
+
     # TODO: Re-enable when the CUDA error: unspecified launch failure is fixed.
     # clear_cuda_cache()

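The validate_attr hook called above is implemented in a part of this commit not shown in this excerpt; per the commit message, validation moved out of the TensorQuantizer forward path and now runs once after calibration. The sketch below only illustrates the kind of post-calibration check such a hook might perform; the function name and behavior here are hypothetical, not modelopt's API:

import warnings
import torch

def check_calibration_attr(quantizer, attr_name, name):
    # Hypothetical stand-in: after calibration, warn once if a calibration
    # tensor such as _amax or _pre_quant_scale contains NaN/Inf, instead of
    # re-validating it on every forward pass.
    value = getattr(quantizer, attr_name, None)
    if isinstance(value, torch.Tensor) and not torch.isfinite(value).all():
        warnings.warn(f"{name}.{attr_name} contains NaN or Inf after calibration")
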
modelopt/torch/quantization/nn/modules/quant_linear.py

Lines changed: 58 additions & 3 deletions
@@ -62,9 +62,67 @@ class QuantLinear(_LegacyQuantLinearConvBaseMixin, nn.Linear):
 Linear = QuantLinear


+class SVDQuantTensorQuantizer(TensorQuantizer):
+    """TensorQuantizer with svdquant LoRA support."""
+
+    @property
+    def svdquant_lora_a(self):
+        """Lora a weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_a"):
+            return None
+        return self._svdquant_lora_a
+
+    @svdquant_lora_a.setter
+    def svdquant_lora_a(self, value):
+        """Lora a weights for svdquant."""
+        assert value is not None, "svdquant_lora_a cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_a"):
+            self.register_buffer("_svdquant_lora_a", value.clone().detach())
+        else:
+            if self._svdquant_lora_a.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
+            self._svdquant_lora_a.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_a.device)
+            )
+
+    @property
+    def svdquant_lora_b(self):
+        """Lora b weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_b"):
+            return None
+        return self._svdquant_lora_b
+
+    @svdquant_lora_b.setter
+    def svdquant_lora_b(self, value):
+        """Lora b weights for svdquant."""
+        assert value is not None, "svdquant_lora_b cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_b"):
+            self.register_buffer("_svdquant_lora_b", value.clone().detach())
+        else:
+            if self._svdquant_lora_b.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
+            self._svdquant_lora_b.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_b.device)
+            )
+
+
 class SVDQuantLinear(QuantLinearConvBase):
     """Base class for quantized linear modules with SVDQuant."""

+    def _setup(self):
+        """Overrides and bypass the _setup function."""
+        if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
+            return
+        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+
     def _not_sequential_quantizers(self):
         return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
             self.input_quantizer, TensorQuantizer
@@ -104,9 +162,6 @@ def forward(self, input, *args, **kwargs):
         output = super().forward(input, *args, **kwargs)
         return output

-    def _setup(self):
-        """Overrides and bypass the _setup function."""
-
     def fold_weight(self):
         """Fold the weight for faster eval."""
         super().fold_weight()
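For orientation, a rough sketch of how low-rank factors of the kind held by svdquant_lora_a / svdquant_lora_b can be derived, following the torch.linalg.svd call in svdquant() above. The shapes and the residual split are illustrative assumptions, not the exact modelopt layout:

import torch

weight = torch.randn(512, 1024, dtype=torch.float64)  # stand-in for module.weight
lowrank = 32

u, s, vt = torch.linalg.svd(weight)
lora_b = u[:, :lowrank] @ torch.diag(s[:lowrank])  # (out_features, lowrank)
lora_a = vt[:lowrank]                              # (lowrank, in_features)

# SVDQuant-style split: a low-rank branch plus a residual that gets quantized.
residual = weight - lora_b @ lora_a

On the quantizer itself, the first assignment to svdquant_lora_a or svdquant_lora_b registers the tensor as a buffer (so it travels with the module's state_dict), and any later assignment must keep the same shape or the setter raises RuntimeError.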
