143 changes: 62 additions & 81 deletions modelopt/torch/quantization/model_calib.py
@@ -25,6 +25,7 @@
import torch.nn.functional as F

from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

@@ -368,13 +369,13 @@ def postprocess(module):
for name, module in model.named_modules():
if is_quantized_linear(module):
if not hasattr(module.input_quantizer, "_amax"):
print(f"Warning: {name} is not calibrated, skip smoothing")
warnings.warn(f"{name} is not calibrated, skip smoothing")
continue
if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
print(f"Warning: only int8 smoothing is supported, skip {name}")
warnings.warn(f"Only int8 smoothing is supported, skip {name}")
continue
if module.input_quantizer.axis != -1:
print(f"Warning: only per-channel smoothing is supported, skip {name}")
warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
continue

assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module)
postprocess(module)

smoothed_modules += 1
print(f"Smoothed {smoothed_modules} modules")


def _smoothquant_fasteval(model: nn.Module):
"""Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
smoothed_modules = 0
for name, module in model.named_modules():
if is_quantized_linear(module):
if not hasattr(module.input_quantizer, "_amax"):
print(f"Warning: {name} is not calibrated, skip smoothing")
continue
if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
print(f"Warning: only int8 smoothing is supported, skip {name}")
continue
if module.input_quantizer.axis != -1:
print(f"Warning: only per-channel smoothing is supported, skip {name}")
continue

assert module.input_quantizer._amax.numel() > 1
delattr(module.weight_quantizer, "_amax")

# It is important to keep scaling math in fp32 to be numerically safe
act_amax = module.input_quantizer.amax.float()
if act_amax.shape[0] == 1:
act_amax = act_amax.squeeze(0)
# If model is split across devices, this tensor may be on wrong one
act_amax = act_amax.to(module.weight.device)

max_bound = module.input_quantizer.maxbound
scale_a = max_bound / act_amax
# Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
epsilon = 1.0 / (1 << 31)
if act_amax.min() <= epsilon:
zero_mask = act_amax <= epsilon
scale_a[zero_mask] = 1
inv_scale_a = act_amax / max_bound

module.weight.data.copy_(
(module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
module.weight.dtype
)
)
module.weight_quantizer.disable()

smoothed_modules += 1
print(f"Smoothed {smoothed_modules} modules")
print_rank_0(f"Smoothed {smoothed_modules} modules")


def awq(
@@ -481,7 +437,9 @@ def awq_lite(
See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
details on the remaining arguments.
"""
assert forward_loop is not None, "forward_loop must be provided for awq_lite"
if forward_loop is None:
warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
return

class AWQLiteHelper:
cache_mode: bool = False
@@ -493,11 +451,32 @@ def __init__(self, module, name):
self.num_search_steps = 0
self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
self.weight_scale = get_weight_scale(module.weight, self.block_size)
self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
self.loss = {
k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
}
self.best_scale = None
self.best_alpha = None
self.is_input_quantized = module.input_quantizer.is_enabled
self.num_tokens = 0
self.module = module
self.is_enabled = True

def setup(self):
module = self.module
bind_forward_method(module, forward, "_forward_no_awq")
if module.input_quantizer.is_enabled:
module.input_quantizer.disable()
if module.input_quantizer.axis not in [None, -1]:
self.is_enabled = False
return
module.input_quantizer.axis = -1

def cleanup(self):
module = self.module
if hasattr(module, "_if_calib"):
delattr(module, "_if_calib")
unpatch_forward_method(module, "_forward_no_awq")

def get_weight_scale(weight, block_size=None):
org_shape = weight.shape
Expand Down Expand Up @@ -534,10 +513,13 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
def update_loss(self, out, out_actual, alpha):
out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
out = out[0] if isinstance(out, tuple) else out
loss = (out - out_actual).float().pow(2).mean().item()
loss = (out - out_actual).float().pow(2).mean()
self.awq_lite.loss[alpha] += loss

def update_best_params(self):
if not self.awq_lite.is_enabled:
return
self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()})
self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
self.awq_lite.best_scale = get_scale(
self.awq_lite.act_scale,
Expand All @@ -560,7 +542,8 @@ def forward(self, input, *args, **kwargs):
out_actual = self._forward_no_awq(input, *args, **kwargs)
self.weight_quantizer.enable()

if input.numel() == 0: # For MoEs, some experts might see 0 tokens
if input.numel() == 0 or not self.awq_lite.is_enabled:
# For MoEs, some experts might see 0 tokens
return out_actual

if AWQLiteHelper.cache_mode:
@@ -589,7 +572,6 @@ def forward(self, input, *args, **kwargs):
self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
out = self._forward_no_awq(input, *args, **kwargs)

update_loss(self, out, out_actual, alpha)

self.awq_lite.num_search_steps += 1
@@ -601,19 +583,11 @@ def forward(self, input, *args, **kwargs):
if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
with enable_weight_access_and_writeback(module, model):
module.awq_lite = AWQLiteHelper(module, name)
bind_forward_method(module, forward, "_forward_no_awq")

if module.input_quantizer.is_enabled:
module.input_quantizer.disable()
if module.input_quantizer.axis not in [None, -1]:
raise NotImplementedError(
"input quantization needs to be per-tensor or None for AWQ algorithm"
)
module.input_quantizer.axis = -1
module.awq_lite.setup()

# Collect activation scale values
AWQLiteHelper.cache_mode = True
print("Caching activation statistics for awq_lite...")
print_rank_0("awq_lite: Caching activation statistics...")

# Lets enable stats collection
# This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -631,22 +605,25 @@ def forward(self, input, *args, **kwargs):
and module.awq_lite.num_cache_steps > 0
):
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
):
module.awq_lite.is_enabled = False
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True

AWQLiteHelper.cache_mode = False
print("Searching awq_lite parameters...")
print_rank_0("awq_lite: Searching parameters...")
with torch.no_grad():
forward_loop(model)

def postprocess(module):
def postprocess(module, name):
update_best_params(module)
if hasattr(module.weight_quantizer, "_pre_quant_scale"):
delattr(module.weight_quantizer, "_pre_quant_scale")
if hasattr(module.input_quantizer, "_pre_quant_scale"):
delattr(module.input_quantizer, "_pre_quant_scale")
if module.awq_lite.is_input_quantized:
assert module.input_quantizer.amax is not None
if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
act_amax = module.input_quantizer.amax
# TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +632,29 @@ def postprocess(module):
module.input_quantizer.amax = act_amax.amax()
module.input_quantizer.enable()

apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
if module.awq_lite.is_enabled:
apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
else:
warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

for name, module in model.named_modules():
if hasattr(module, "awq_lite"):
if module.awq_lite.num_cache_steps > 0:
assert module.awq_lite.num_search_steps > 0, (
"Calling `forward_loop(model)` the second time did not forward data through the"
" model. Please provide a valid `forward_loop` function that can be used to"
if module.awq_lite.num_cache_steps == 0:
module.awq_lite.is_enabled = False
elif module.awq_lite.num_search_steps == 0:
module.awq_lite.is_enabled = False
warnings.warn(
"awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
f" {name}. Please provide a valid `forward_loop` function that can be used to"
" forward data through the model many times."
)
with enable_weight_access_and_writeback(module, model):
postprocess(module)
with enable_weight_access_and_writeback(module, model):
postprocess(module, name)

module.awq_lite.cleanup()
if not debug:
delattr(module, "awq_lite")
if hasattr(module, "_if_calib"):
delattr(module, "_if_calib")

unpatch_forward_method(module, "_forward_no_awq")


@torch.no_grad()
@@ -858,7 +839,7 @@ def forward(name, self, input, *args, **kwargs):
with enable_weight_access_and_writeback(module, model):
module.awq_clip = AWQClipHelper(module)

print("Estimating awq_clip parameters...")
print_rank_0("awq_clip: Estimating parameters...")
# Lets enable stats collection
# This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
enable_stats_collection(model)
@@ -919,7 +900,7 @@ def svdquant(
"""

def postprocess(module, name):
print(f"SVD {name}")
print_rank_0(f"SVD {name}")
u, s, vt = torch.linalg.svd(module.weight.data.double())
if u.shape[1] < lowrank or vt.shape[0] < lowrank:
warnings.warn(
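Note on the logging change in this file: bare `print` calls are replaced with `print_rank_0` from `modelopt.torch.utils` (and user-facing warnings now go through `warnings.warn`), so calibration progress is reported once per job rather than once per process in multi-GPU runs. The helper's implementation is not part of this diff; the snippet below is only a minimal sketch of the usual rank-0 guard pattern, not the actual modelopt code.

import torch.distributed as dist

def print_rank_0(message: str) -> None:
    # Sketch (assumed pattern): print only on the global rank-0 process,
    # falling back to a plain print when torch.distributed is not initialized.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(message, flush=True)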
5 changes: 5 additions & 0 deletions modelopt/torch/quantization/model_quant.py
@@ -105,6 +105,11 @@ def forward_loop(model):
mode_kwargs={"forward_loop": forward_loop},
)

for name, module in model.named_modules():
if isinstance(module, TensorQuantizer):
for attr_name in ["_amax", "_pre_quant_scale"]:
module.validate_attr(attr_name=attr_name, warn_error=True, name=name)

# TODO: Re-enable when the CUDA error: unspecified launch failure is fixed.
# clear_cuda_cache()

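The new loop in model_quant.py sanity-checks every `TensorQuantizer` after calibration by calling `validate_attr` on `_amax` and `_pre_quant_scale`. The method itself lives on the quantizer class and is not shown in this excerpt; a minimal sketch, assuming it only warns (or raises) when a calibration buffer contains NaN or Inf values:

import warnings
import torch

def validate_attr(self, attr_name: str, warn_error: bool = False, name: str = "") -> bool:
    # Hypothetical sketch of the validator, not the actual modelopt implementation.
    value = getattr(self, attr_name, None)
    if not isinstance(value, torch.Tensor):
        return True  # attribute absent or not a tensor: nothing to validate
    if torch.isnan(value).any() or torch.isinf(value).any():
        msg = f"{name}: {attr_name} contains NaN/Inf after calibration"
        if warn_error:
            warnings.warn(msg)
            return False
        raise ValueError(msg)
    return True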
61 changes: 58 additions & 3 deletions modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -62,9 +62,67 @@ class QuantLinear(_LegacyQuantLinearConvBaseMixin, nn.Linear):
Linear = QuantLinear


class SVDQuantTensorQuantizer(TensorQuantizer):
"""TensorQuantizer with svdquant LoRA support."""

@property
def svdquant_lora_a(self):
"""Lora a weights for svdquant."""
if not hasattr(self, "_svdquant_lora_a"):
return None
return self._svdquant_lora_a

@svdquant_lora_a.setter
def svdquant_lora_a(self, value):
"""Lora a weights for svdquant."""
assert value is not None, "svdquant_lora_a cannot be set to None."

if not isinstance(value, torch.Tensor):
value = torch.tensor(value)

if not hasattr(self, "_svdquant_lora_a"):
self.register_buffer("_svdquant_lora_a", value.clone().detach())
else:
if self._svdquant_lora_a.shape != value.shape:
raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
self._svdquant_lora_a.data.copy_(
value.clone().detach().to(self._svdquant_lora_a.device)
)

@property
def svdquant_lora_b(self):
"""Lora b weights for svdquant."""
if not hasattr(self, "_svdquant_lora_b"):
return None
return self._svdquant_lora_b

@svdquant_lora_b.setter
def svdquant_lora_b(self, value):
"""Lora b weights for svdquant."""
assert value is not None, "svdquant_lora_b cannot be set to None."

if not isinstance(value, torch.Tensor):
value = torch.tensor(value)

if not hasattr(self, "_svdquant_lora_b"):
self.register_buffer("_svdquant_lora_b", value.clone().detach())
else:
if self._svdquant_lora_b.shape != value.shape:
raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
self._svdquant_lora_b.data.copy_(
value.clone().detach().to(self._svdquant_lora_b.device)
)


class SVDQuantLinear(QuantLinearConvBase):
"""Base class for quantized linear modules with SVDQuant."""

def _setup(self):
"""Overrides and bypass the _setup function."""
if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
return
self.weight_quantizer.__class__ = SVDQuantTensorQuantizer

def _not_sequential_quantizers(self):
return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
self.input_quantizer, TensorQuantizer
@@ -104,9 +162,6 @@ def forward(self, input, *args, **kwargs):
output = super().forward(input, *args, **kwargs)
return output

def _setup(self):
"""Overrides and bypass the _setup function."""

def fold_weight(self):
"""Fold the weight for faster eval."""
super().fold_weight()
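For context on the new `svdquant_lora_a` / `svdquant_lora_b` buffers: the `svdquant` calibration in model_calib.py takes an SVD of the weight (see the `u, s, vt = torch.linalg.svd(module.weight.data.double())` line above) and stores a rank-`lowrank` pair so the low-rank branch can be added back on top of the quantized residual. A minimal sketch of such a factorization, assuming `lora_b @ lora_a` reconstructs the low-rank part (the exact split used by modelopt may differ):

import torch

def svd_lowrank_split(weight: torch.Tensor, lowrank: int):
    # Sketch: weight ≈ residual + lora_b @ lora_a, where the residual is what gets quantized.
    u, s, vt = torch.linalg.svd(weight.double(), full_matrices=False)
    lora_a = (torch.diag(s[:lowrank]) @ vt[:lowrank]).to(weight.dtype)  # [lowrank, in_features]
    lora_b = u[:, :lowrank].to(weight.dtype)                            # [out_features, lowrank]
    residual = weight - lora_b @ lora_a
    return lora_a, lora_b, residual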