Commit 4623fb0

TensorQuantizer: remove validation from forward path to reduce cpu-gpu sync, new SVDQuantTensorQuantizer, minor code clean up
1 parent c0590b0 commit 4623fb0
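
For context on the first change in the commit message: a validity check whose boolean result drives Python control flow (for example, torch.isfinite(amax).all()) forces the host to wait for the GPU, so running such a check inside every TensorQuantizer forward adds one sync per call. Moving it to a single pass after calibration removes that per-call cost. A minimal illustration of the sync itself (generic PyTorch, not code from this commit):

import torch

amax = torch.ones(8, device="cuda" if torch.cuda.is_available() else "cpu")
# Branching on a GPU-side reduction forces a device-to-host copy (a cpu-gpu sync):
if not torch.isfinite(amax).all():
    raise ValueError("amax contains NaN/Inf")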

File tree

Showing 4 changed files with 163 additions and 149 deletions.

modelopt/torch/quantization/model_calib.py

Lines changed: 69 additions & 78 deletions
@@ -25,6 +25,7 @@
 import torch.nn.functional as F

 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                print_rank_0(f"Warning: {name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                print_rank_0(f"Warning: only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                print_rank_0(f"Warning: only per-channel smoothing is supported, skip {name}")
                 continue

             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                 postprocess(module)

             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


 def awq(
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return

     class AWQLiteHelper:
         cache_mode: bool = False
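
Since awq_lite now warns and returns instead of asserting when forward_loop is missing, callers should make sure calibration data is actually supplied. A typical call through the public API (the dataloader and config choice here are hypothetical):

import modelopt.torch.quantization as mtq

def forward_loop(model):
    for batch in calib_dataloader:  # hypothetical calibration dataloader
        model(batch)

model = mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop=forward_loop)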
@@ -498,6 +456,25 @@ def __init__(self, module, name):
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.setup()
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")

     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
@@ -538,6 +515,8 @@ def update_loss(self, out, out_actual, alpha):
         self.awq_lite.loss[alpha] += loss

     def update_best_params(self):
+        if not module.awq_lite.is_enabled:
+            return
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
@@ -560,7 +539,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()

-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
             return out_actual

         if AWQLiteHelper.cache_mode:
@@ -588,6 +568,23 @@ def forward(self, input, *args, **kwargs):
             )
             self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
             self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
+
+            disable_awq = False
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr):
+                        disable_awq = True
+                        warnings.warn(
+                            f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
+                        )
+                        break
+                if disable_awq:
+                    break
+
+            if disable_awq:
+                self.awq_lite.is_enabled = False
+                return out_actual
+
             out = self._forward_no_awq(input, *args, **kwargs)

             update_loss(self, out, out_actual, alpha)
@@ -601,19 +598,10 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1

     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")

     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -635,18 +623,17 @@ def forward(self, input, *args, **kwargs):
             module._if_calib = True

     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +642,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)

+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


 @torch.no_grad()
@@ -858,7 +849,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +910,7 @@ def svdquant(
     """

     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(

modelopt/torch/quantization/model_quant.py

Lines changed: 6 additions & 1 deletion
@@ -105,6 +105,11 @@ def forward_loop(model):
         mode_kwargs={"forward_loop": forward_loop},
     )

+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer):
+            for attr_name in ["_amax", "_bias", "_pre_quant_scale"]:
+                module.validate_attr(attr_name=attr_name, raise_error=True, name=name)
+
     # TODO: Re-enable when the CUDA error: unspecified launch failure is fixed.
     # clear_cuda_cache()

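The loop added above validates calibrated attributes once, after calibration, instead of on every forward call. The real validate_attr lives in the TensorQuantizer changes of this commit, which are not shown in this excerpt; a hypothetical stand-in to convey the idea:

import torch

def validate_attr_sketch(quantizer, attr_name: str, raise_error: bool = False, name: str = "") -> bool:
    """Hypothetical stand-in for TensorQuantizer.validate_attr: check a calibration buffer once."""
    value = getattr(quantizer, attr_name, None)
    if value is None:
        return True  # attribute not set; nothing to validate
    is_valid = bool(torch.isfinite(value).all())  # one cpu-gpu sync, outside the forward path
    if not is_valid and raise_error:
        raise ValueError(f"{name or attr_name}: {attr_name} contains NaN/Inf after calibration")
    return is_valid
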
@@ -136,7 +141,7 @@ def quantize(
     """Quantizes and calibrates the model in-place.

     This method performs replacement of modules with their quantized counterparts and
-    performs calibration as specified by ``quant_cfg``.
+    performs calibration as specified by ``quant_cfg``.
     ``forward_loop`` is used to forward data through the model and gather statistics for calibration.

     Args:

modelopt/torch/quantization/nn/modules/quant_linear.py

Lines changed: 56 additions & 3 deletions
@@ -62,9 +62,65 @@ class QuantLinear(_LegacyQuantLinearConvBaseMixin, nn.Linear):
 Linear = QuantLinear


+class SVDQuantTensorQuantizer(TensorQuantizer):
+    """TensorQuantizer with svdquant LoRA support."""
+
+    @property
+    def svdquant_lora_a(self):
+        """Lora a weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_a"):
+            return None
+        return self._svdquant_lora_a
+
+    @svdquant_lora_a.setter
+    def svdquant_lora_a(self, value):
+        """Lora a weights for svdquant."""
+        assert value is not None, "svdquant_lora_a cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_a"):
+            self.register_buffer("_svdquant_lora_a", value.clone().detach())
+        else:
+            if self._svdquant_lora_a.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
+            self._svdquant_lora_a.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_a.device)
+            )
+
+    @property
+    def svdquant_lora_b(self):
+        """Lora b weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_b"):
+            return None
+        return self._svdquant_lora_b
+
+    @svdquant_lora_b.setter
+    def svdquant_lora_b(self, value):
+        """Lora b weights for svdquant."""
+        assert value is not None, "svdquant_lora_b cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_b"):
+            self.register_buffer("_svdquant_lora_b", value.clone().detach())
+        else:
+            if self._svdquant_lora_b.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
+            self._svdquant_lora_b.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_b.device)
+            )
+
+
 class SVDQuantLinear(QuantLinearConvBase):
     """Base class for quantized linear modules with SVDQuant."""

+    def _setup(self):
+        """Overrides and bypass the _setup function."""
+        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+
     def _not_sequential_quantizers(self):
         return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
             self.input_quantizer, TensorQuantizer
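
A minimal usage sketch of the new quantizer class (illustrative only; the shapes and rank below are arbitrary, and the import path assumes the defining module above):

import torch
from modelopt.torch.quantization.nn.modules.quant_linear import SVDQuantTensorQuantizer

q = SVDQuantTensorQuantizer()
q.svdquant_lora_a = torch.randn(16, 1024)   # first assignment registers a buffer
q.svdquant_lora_b = torch.randn(1024, 16)
q.svdquant_lora_a = torch.zeros(16, 1024)   # same shape: copied in place
# q.svdquant_lora_a = torch.zeros(8, 1024)  # different shape: would raise RuntimeError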
@@ -104,9 +160,6 @@ def forward(self, input, *args, **kwargs):
         output = super().forward(input, *args, **kwargs)
         return output

-    def _setup(self):
-        """Overrides and bypass the _setup function."""
-
     def fold_weight(self):
         """Fold the weight for faster eval."""
         super().fold_weight()