
Commit 2f7b2eb

realAsma committed
TensorQuantizer: remove validation from forward path to reduce cpu-gpu sync, new SVDQuantTensorQuantizer, minor code clean up
Signed-off-by: realAsma <[email protected]>
1 parent 0178562 commit 2f7b2eb
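
The "reduce cpu-gpu sync" part of this change shows up below in AWQLiteHelper: per-alpha losses are now accumulated as 0-dim device tensors instead of Python floats, so the host no longer blocks on `.item()` at every search step. A minimal standalone sketch of that pattern (the names `alphas` and `quantize_error` are illustrative placeholders, not modelopt APIs):

import torch

# Before: accumulating into Python floats forces a device->host sync on every step.
# After: accumulating into 0-dim device tensors defers the sync to a single final read.
alphas = [round(a, 2) for a in torch.arange(0, 1.01, 0.1).tolist()]
device = "cuda" if torch.cuda.is_available() else "cpu"

def quantize_error(alpha: float) -> torch.Tensor:
    return torch.rand((), device=device)  # stand-in for the per-step quantization loss

loss_synced = dict.fromkeys(alphas, 0.0)
for alpha in alphas:
    loss_synced[alpha] += quantize_error(alpha).item()  # sync on every step

loss_deferred = {alpha: torch.zeros((), device=device) for alpha in alphas}
for alpha in alphas:
    loss_deferred[alpha] += quantize_error(alpha)  # stays on device
best_alpha = min(loss_deferred, key=lambda a: loss_deferred[a].item())  # one sync at the end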

File tree

4 files changed: +166 -150 lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 73 additions & 79 deletions
@@ -25,6 +25,7 @@
 import torch.nn.functional as F

 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                 continue

             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                 postprocess(module)

             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


 def awq(
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
     """
-    assert forward_loop is not None, "forward_loop must be provided for awq_lite"
+    if forward_loop is None:
+        warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
+        return

     class AWQLiteHelper:
         cache_mode: bool = False
@@ -493,11 +451,32 @@ def __init__(self, module, name):
             self.num_search_steps = 0
             self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
             self.weight_scale = get_weight_scale(module.weight, self.block_size)
-            self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
+            self.loss = {
+                k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
             self.best_scale = None
             self.best_alpha = None
             self.is_input_quantized = module.input_quantizer.is_enabled
             self.num_tokens = 0
+            self.module = module
+            self.is_enabled = True
+
+        def setup(self):
+            module = self.module
+            bind_forward_method(module, forward, "_forward_no_awq")
+            if module.input_quantizer.is_enabled:
+                module.input_quantizer.disable()
+                if module.input_quantizer.axis not in [None, -1]:
+                    self.is_enabled = False
+                    return
+                module.input_quantizer.axis = -1
+
+        def cleanup(self):
+            module = self.module
+            if hasattr(module, "_if_calib"):
+                delattr(module, "_if_calib")
+            unpatch_forward_method(module, "_forward_no_awq")

     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
@@ -538,6 +517,8 @@ def update_loss(self, out, out_actual, alpha):
         self.awq_lite.loss[alpha] += loss

     def update_best_params(self):
+        if not self.awq_lite.is_enabled:
+            return
         self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
         self.awq_lite.best_scale = get_scale(
             self.awq_lite.act_scale,
@@ -560,7 +541,8 @@ def forward(self, input, *args, **kwargs):
         out_actual = self._forward_no_awq(input, *args, **kwargs)
         self.weight_quantizer.enable()

-        if input.numel() == 0:  # For MoEs, some experts might see 0 tokens
+        if input.numel() == 0 or not self.awq_lite.is_enabled:
+            # For MoEs, some experts might see 0 tokens
             return out_actual

         if AWQLiteHelper.cache_mode:
@@ -588,6 +570,23 @@ def forward(self, input, *args, **kwargs):
             )
             self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
             self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
+
+            disable_awq = False
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr):
+                        disable_awq = True
+                        warnings.warn(
+                            f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
+                        )
+                        break
+                if disable_awq:
+                    break
+
+            if disable_awq:
+                self.awq_lite.is_enabled = False
+                return out_actual
+
             out = self._forward_no_awq(input, *args, **kwargs)

             update_loss(self, out, out_actual, alpha)
@@ -601,19 +600,11 @@ def forward(self, input, *args, **kwargs):
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
-            bind_forward_method(module, forward, "_forward_no_awq")
-
-            if module.input_quantizer.is_enabled:
-                module.input_quantizer.disable()
-                if module.input_quantizer.axis not in [None, -1]:
-                    raise NotImplementedError(
-                        "input quantization needs to be per-tensor or None for AWQ algorithm"
-                    )
-                module.input_quantizer.axis = -1
+            module.awq_lite.setup()

     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
-    print("Caching activation statistics for awq_lite...")
+    print_rank_0("awq_lite: Caching activation statistics...")

     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -635,18 +626,17 @@ def forward(self, input, *args, **kwargs):
            module._if_calib = True

     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +645,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)

+            module.awq_lite.cleanup()
             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


 @torch.no_grad()
@@ -858,7 +852,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +913,7 @@ def svdquant(
     """

     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(
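
The awq_lite changes above also fold the forward-method patching into explicit `setup()`/`cleanup()` methods on AWQLiteHelper instead of inline code in the module loop. A rough sketch of that lifecycle using plain attribute patching (modelopt's `bind_forward_method`/`unpatch_forward_method` do the real work; the helpers below are illustrative only):

import types
import torch
import torch.nn as nn

def calib_forward(self, x):
    out = self._forward_no_awq(x)  # original behavior
    self.num_calib_steps += 1      # side channel for calibration statistics
    return out

def setup(module: nn.Module):
    # Remember the original bound forward and install a calibration wrapper.
    module._forward_no_awq = module.forward
    module.num_calib_steps = 0
    module.forward = types.MethodType(calib_forward, module)

def cleanup(module: nn.Module):
    # Drop the instance-level forward so the class-level forward is used again,
    # then remove the temporary calibration attributes.
    del module.forward
    for attr in ("_forward_no_awq", "num_calib_steps"):
        if hasattr(module, attr):
            delattr(module, attr)

layer = nn.Linear(4, 4)
setup(layer)
layer(torch.randn(2, 4))
assert layer.num_calib_steps == 1
cleanup(layer)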

modelopt/torch/quantization/model_quant.py

Lines changed: 5 additions & 0 deletions
@@ -105,6 +105,11 @@ def forward_loop(model):
         mode_kwargs={"forward_loop": forward_loop},
     )

+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer):
+            for attr_name in ["_amax", "_pre_quant_scale"]:
+                module.validate_attr(attr_name=attr_name, warn_error=True, name=name)
+
     # TODO: Re-enable when the CUDA error: unspecified launch failure is fixed.
     # clear_cuda_cache()
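
The hunk above moves validation out of the quantizer forward path: after calibration, `_amax` and `_pre_quant_scale` are checked once per TensorQuantizer via `validate_attr`. That method's implementation is not part of this diff; a hypothetical sketch of what such a one-time, post-calibration check might look like:

import warnings
import torch

def check_quantizer_attr(quantizer, attr_name: str, name: str = "") -> bool:
    """Hypothetical stand-in for a post-calibration attribute check."""
    value = getattr(quantizer, attr_name, None)
    if value is None:
        return True  # attribute was never populated; nothing to validate
    # torch.isfinite(...).all() syncs with the device, but only once after
    # calibration rather than on every forward call.
    is_valid = bool(torch.isfinite(value).all())
    if not is_valid:
        warnings.warn(f"{name}: {attr_name} contains NaN or Inf after calibration")
    return is_valid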

modelopt/torch/quantization/nn/modules/quant_linear.py

Lines changed: 56 additions & 3 deletions
@@ -62,9 +62,65 @@ class QuantLinear(_LegacyQuantLinearConvBaseMixin, nn.Linear):
 Linear = QuantLinear


+class SVDQuantTensorQuantizer(TensorQuantizer):
+    """TensorQuantizer with svdquant LoRA support."""
+
+    @property
+    def svdquant_lora_a(self):
+        """Lora a weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_a"):
+            return None
+        return self._svdquant_lora_a
+
+    @svdquant_lora_a.setter
+    def svdquant_lora_a(self, value):
+        """Lora a weights for svdquant."""
+        assert value is not None, "svdquant_lora_a cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_a"):
+            self.register_buffer("_svdquant_lora_a", value.clone().detach())
+        else:
+            if self._svdquant_lora_a.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
+            self._svdquant_lora_a.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_a.device)
+            )
+
+    @property
+    def svdquant_lora_b(self):
+        """Lora b weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_b"):
+            return None
+        return self._svdquant_lora_b
+
+    @svdquant_lora_b.setter
+    def svdquant_lora_b(self, value):
+        """Lora b weights for svdquant."""
+        assert value is not None, "svdquant_lora_b cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_b"):
+            self.register_buffer("_svdquant_lora_b", value.clone().detach())
+        else:
+            if self._svdquant_lora_b.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
+            self._svdquant_lora_b.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_b.device)
+            )
+
+
 class SVDQuantLinear(QuantLinearConvBase):
     """Base class for quantized linear modules with SVDQuant."""

+    def _setup(self):
+        """Overrides and bypass the _setup function."""
+        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+
     def _not_sequential_quantizers(self):
         return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
             self.input_quantizer, TensorQuantizer
@@ -104,9 +160,6 @@ def forward(self, input, *args, **kwargs):
         output = super().forward(input, *args, **kwargs)
         return output

-    def _setup(self):
-        """Overrides and bypass the _setup function."""
-
     def fold_weight(self):
         """Fold the weight for faster eval."""
         super().fold_weight()
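
A short usage sketch of the new `SVDQuantTensorQuantizer` buffers, assuming the class is importable from the module where this commit defines it and using arbitrary example shapes. The first assignment registers a buffer, later assignments must keep the same shape and are copied in place, so the LoRA factors show up in `state_dict()`:

import torch
from modelopt.torch.quantization.nn.modules.quant_linear import SVDQuantTensorQuantizer

tq = SVDQuantTensorQuantizer()             # default-constructed quantizer
tq.svdquant_lora_a = torch.randn(32, 128)  # first set registers the "_svdquant_lora_a" buffer
tq.svdquant_lora_b = torch.randn(256, 32)  # first set registers the "_svdquant_lora_b" buffer

tq.svdquant_lora_a = torch.zeros(32, 128)  # same shape: copied into the existing buffer
try:
    tq.svdquant_lora_a = torch.zeros(16, 128)  # different shape: rejected
except RuntimeError as err:
    print(err)

# Registered buffers travel with the quantizer's state_dict.
print(sorted(k for k in tq.state_dict() if "svdquant" in k))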

0 commit comments
