diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 5276b1334..3ef014d65 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -25,6 +25,7 @@
 import torch.nn.functional as F

 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

@@ -368,13 +369,13 @@ def postprocess(module):
     for name, module in model.named_modules():
         if is_quantized_linear(module):
             if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
+                warnings.warn(f"{name} is not calibrated, skip smoothing")
                 continue
             if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
+                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                 continue
             if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
+                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                 continue

             assert module.input_quantizer._amax.numel() > 1, (
@@ -385,52 +386,7 @@ def postprocess(module):
                 postprocess(module)

             smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
-
-
-def _smoothquant_fasteval(model: nn.Module):
-    """Hacky implementation of Smooth-Quant. Copied from monkey-quant."""
-    smoothed_modules = 0
-    for name, module in model.named_modules():
-        if is_quantized_linear(module):
-            if not hasattr(module.input_quantizer, "_amax"):
-                print(f"Warning: {name} is not calibrated, skip smoothing")
-                continue
-            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
-                print(f"Warning: only int8 smoothing is supported, skip {name}")
-                continue
-            if module.input_quantizer.axis != -1:
-                print(f"Warning: only per-channel smoothing is supported, skip {name}")
-                continue
-
-            assert module.input_quantizer._amax.numel() > 1
-            delattr(module.weight_quantizer, "_amax")
-
-            # It is important to keep scaling math in fp32 to be numerically safe
-            act_amax = module.input_quantizer.amax.float()
-            if act_amax.shape[0] == 1:
-                act_amax = act_amax.squeeze(0)
-            # If model is split across devices, this tensor may be on wrong one
-            act_amax = act_amax.to(module.weight.device)
-
-            max_bound = module.input_quantizer.maxbound
-            scale_a = max_bound / act_amax
-            # Some channel could have 0 amax which causes scale_a to overflow. Explicitly mask them out here
-            epsilon = 1.0 / (1 << 31)
-            if act_amax.min() <= epsilon:
-                zero_mask = act_amax <= epsilon
-                scale_a[zero_mask] = 1
-            inv_scale_a = act_amax / max_bound
-
-            module.weight.data.copy_(
-                (module.weight_quantizer(inv_scale_a * module.weight.float()) * scale_a).to(
-                    module.weight.dtype
-                )
-            )
-            module.weight_quantizer.disable()
-
-            smoothed_modules += 1
-    print(f"Smoothed {smoothed_modules} modules")
+    print_rank_0(f"Smoothed {smoothed_modules} modules")


 def awq(
@@ -481,7 +437,9 @@ def awq_lite(
     See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
     details on the remaining arguments.
""" - assert forward_loop is not None, "forward_loop must be provided for awq_lite" + if forward_loop is None: + warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite") + return class AWQLiteHelper: cache_mode: bool = False @@ -493,11 +451,32 @@ def __init__(self, module, name): self.num_search_steps = 0 self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer) self.weight_scale = get_weight_scale(module.weight, self.block_size) - self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)} + self.loss = { + k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32) + for k in torch.arange(0, 1.0 + alpha_step, alpha_step) + } self.best_scale = None self.best_alpha = None self.is_input_quantized = module.input_quantizer.is_enabled self.num_tokens = 0 + self.module = module + self.is_enabled = True + + def setup(self): + module = self.module + bind_forward_method(module, forward, "_forward_no_awq") + if module.input_quantizer.is_enabled: + module.input_quantizer.disable() + if module.input_quantizer.axis not in [None, -1]: + self.is_enabled = False + return + module.input_quantizer.axis = -1 + + def cleanup(self): + module = self.module + if hasattr(module, "_if_calib"): + delattr(module, "_if_calib") + unpatch_forward_method(module, "_forward_no_awq") def get_weight_scale(weight, block_size=None): org_shape = weight.shape @@ -534,10 +513,13 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None): def update_loss(self, out, out_actual, alpha): out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual out = out[0] if isinstance(out, tuple) else out - loss = (out - out_actual).float().pow(2).mean().item() + loss = (out - out_actual).float().pow(2).mean() self.awq_lite.loss[alpha] += loss def update_best_params(self): + if not self.awq_lite.is_enabled: + return + self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()}) self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get) self.awq_lite.best_scale = get_scale( self.awq_lite.act_scale, @@ -560,7 +542,8 @@ def forward(self, input, *args, **kwargs): out_actual = self._forward_no_awq(input, *args, **kwargs) self.weight_quantizer.enable() - if input.numel() == 0: # For MoEs, some experts might see 0 tokens + if input.numel() == 0 or not self.awq_lite.is_enabled: + # For MoEs, some experts might see 0 tokens return out_actual if AWQLiteHelper.cache_mode: @@ -589,7 +572,6 @@ def forward(self, input, *args, **kwargs): self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype) self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype) out = self._forward_no_awq(input, *args, **kwargs) - update_loss(self, out, out_actual, alpha) self.awq_lite.num_search_steps += 1 @@ -601,19 +583,11 @@ def forward(self, input, *args, **kwargs): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model): module.awq_lite = AWQLiteHelper(module, name) - bind_forward_method(module, forward, "_forward_no_awq") - - if module.input_quantizer.is_enabled: - module.input_quantizer.disable() - if module.input_quantizer.axis not in [None, -1]: - raise NotImplementedError( - "input quantization needs to be per-tensor or None for AWQ algorithm" - ) - module.input_quantizer.axis = -1 + module.awq_lite.setup() # Collect activation scale values AWQLiteHelper.cache_mode = True - print("Caching activation statistics for 
+    print_rank_0("awq_lite: Caching activation statistics...")

     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
@@ -631,22 +605,25 @@ def forward(self, input, *args, **kwargs):
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+                torch.isnan(module.awq_lite.weight_scale)
+            ):
+                module.awq_lite.is_enabled = False
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True

     AWQLiteHelper.cache_mode = False
-    print("Searching awq_lite parameters...")
+    print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)

-    def postprocess(module):
+    def postprocess(module, name):
         update_best_params(module)
         if hasattr(module.weight_quantizer, "_pre_quant_scale"):
             delattr(module.weight_quantizer, "_pre_quant_scale")
         if hasattr(module.input_quantizer, "_pre_quant_scale"):
             delattr(module.input_quantizer, "_pre_quant_scale")
-        if module.awq_lite.is_input_quantized:
-            assert module.input_quantizer.amax is not None
+        if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
             act_amax = module.input_quantizer.amax
             # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
             module.input_quantizer._amax_for_smoothing = act_amax.cpu()
@@ -655,25 +632,29 @@ def postprocess(module):
             module.input_quantizer.amax = act_amax.amax()
             module.input_quantizer.enable()

-        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        else:
+            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
-            if module.awq_lite.num_cache_steps > 0:
-                assert module.awq_lite.num_search_steps > 0, (
-                    "Calling `forward_loop(model)` the second time did not forward data through the"
-                    " model. Please provide a valid `forward_loop` function that can be used to"
+            if module.awq_lite.num_cache_steps == 0:
+                module.awq_lite.is_enabled = False
+            elif module.awq_lite.num_search_steps == 0:
+                module.awq_lite.is_enabled = False
+                warnings.warn(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
                     " forward data through the model many times."
                 )
-                with enable_weight_access_and_writeback(module, model):
-                    postprocess(module)
+            with enable_weight_access_and_writeback(module, model):
+                postprocess(module, name)
+            module.awq_lite.cleanup()

             if not debug:
                 delattr(module, "awq_lite")
-            if hasattr(module, "_if_calib"):
-                delattr(module, "_if_calib")
-
-            unpatch_forward_method(module, "_forward_no_awq")


 @torch.no_grad()
@@ -858,7 +839,7 @@ def forward(name, self, input, *args, **kwargs):
             with enable_weight_access_and_writeback(module, model):
                 module.awq_clip = AWQClipHelper(module)

-    print("Estimating awq_clip parameters...")
+    print_rank_0("awq_clip: Estimating parameters...")
     # Lets enable stats collection
     # This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
     enable_stats_collection(model)
@@ -919,7 +900,7 @@ def svdquant(
     """

     def postprocess(module, name):
-        print(f"SVD {name}")
+        print_rank_0(f"SVD {name}")
         u, s, vt = torch.linalg.svd(module.weight.data.double())
         if u.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(
diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py
index e29367963..deace8e0c 100644
--- a/modelopt/torch/quantization/model_quant.py
+++ b/modelopt/torch/quantization/model_quant.py
@@ -105,6 +105,11 @@ def forward_loop(model):
         mode_kwargs={"forward_loop": forward_loop},
     )

+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer):
+            for attr_name in ["_amax", "_pre_quant_scale"]:
+                module.validate_attr(attr_name=attr_name, warn_error=True, name=name)
+
     # TODO: Re-enable when the CUDA error: unspecified launch failure is fixed.
     # clear_cuda_cache()

diff --git a/modelopt/torch/quantization/nn/modules/quant_linear.py b/modelopt/torch/quantization/nn/modules/quant_linear.py
index 44655f7e2..f1d601557 100644
--- a/modelopt/torch/quantization/nn/modules/quant_linear.py
+++ b/modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -62,9 +62,67 @@ class QuantLinear(_LegacyQuantLinearConvBaseMixin, nn.Linear):
 Linear = QuantLinear


+class SVDQuantTensorQuantizer(TensorQuantizer):
+    """TensorQuantizer with svdquant LoRA support."""
+
+    @property
+    def svdquant_lora_a(self):
+        """Lora a weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_a"):
+            return None
+        return self._svdquant_lora_a
+
+    @svdquant_lora_a.setter
+    def svdquant_lora_a(self, value):
+        """Lora a weights for svdquant."""
+        assert value is not None, "svdquant_lora_a cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_a"):
+            self.register_buffer("_svdquant_lora_a", value.clone().detach())
+        else:
+            if self._svdquant_lora_a.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
+            self._svdquant_lora_a.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_a.device)
+            )
+
+    @property
+    def svdquant_lora_b(self):
+        """Lora b weights for svdquant."""
+        if not hasattr(self, "_svdquant_lora_b"):
+            return None
+        return self._svdquant_lora_b
+
+    @svdquant_lora_b.setter
+    def svdquant_lora_b(self, value):
+        """Lora b weights for svdquant."""
+        assert value is not None, "svdquant_lora_b cannot be set to None."
+
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        if not hasattr(self, "_svdquant_lora_b"):
+            self.register_buffer("_svdquant_lora_b", value.clone().detach())
+        else:
+            if self._svdquant_lora_b.shape != value.shape:
+                raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
+            self._svdquant_lora_b.data.copy_(
+                value.clone().detach().to(self._svdquant_lora_b.device)
+            )
+
+
 class SVDQuantLinear(QuantLinearConvBase):
     """Base class for quantized linear modules with SVDQuant."""

+    def _setup(self):
+        """Overrides and bypass the _setup function."""
+        if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
+            return
+        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+
     def _not_sequential_quantizers(self):
         return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
             self.input_quantizer, TensorQuantizer
@@ -104,9 +162,6 @@ def forward(self, input, *args, **kwargs):
             output = super().forward(input, *args, **kwargs)
         return output

-    def _setup(self):
-        """Overrides and bypass the _setup function."""
-
     def fold_weight(self):
         """Fold the weight for faster eval."""
         super().fold_weight()
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 6e431dce9..bf801646b 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -28,6 +28,11 @@
 except ImportError:
     DTensor = None

+if hasattr(torch.onnx, "_globals"):
+    from torch.onnx._globals import GLOBALS
+else:  # torch >= 2.9
+    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+
 import torch.nn.functional as F
 from torch import nn

@@ -254,6 +259,7 @@ def reset_amax(self):
         if hasattr(self, "_amax"):
             delattr(self, "_amax")
         self._calibrator.reset()
+        self.reset_bias()

     def reset_bias(self):
         """Reset bias to None."""
@@ -419,54 +425,6 @@ def is_mx_format(self):
             and self.block_sizes.get("scale_bits", None) == (8, 0)
         )

-    @property
-    def svdquant_lora_a(self):
-        """Lora a weights for svdquant."""
-        if not hasattr(self, "_svdquant_lora_a"):
-            return None
-        return self._svdquant_lora_a
-
-    @svdquant_lora_a.setter
-    def svdquant_lora_a(self, value):
-        """Lora a weights for svdquant."""
-        assert value is not None, "svdquant_lora_a cannot be set to None."
-
-        if not isinstance(value, torch.Tensor):
-            value = torch.tensor(value)
-
-        if not hasattr(self, "_svdquant_lora_a"):
-            self.register_buffer("_svdquant_lora_a", value.clone().detach())
-        else:
-            if self._svdquant_lora_a.shape != value.shape:
-                raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
-            self._svdquant_lora_a.data.copy_(
-                value.clone().detach().to(self._svdquant_lora_a.device)
-            )
-
-    @property
-    def svdquant_lora_b(self):
-        """Lora b weights for svdquant."""
-        if not hasattr(self, "_svdquant_lora_b"):
-            return None
-        return self._svdquant_lora_b
-
-    @svdquant_lora_b.setter
-    def svdquant_lora_b(self, value):
-        """Lora b weights for svdquant."""
-        assert value is not None, "svdquant_lora_b cannot be set to None."
-
-        if not isinstance(value, torch.Tensor):
-            value = torch.tensor(value)
-
-        if not hasattr(self, "_svdquant_lora_b"):
-            self.register_buffer("_svdquant_lora_b", value.clone().detach())
-        else:
-            if self._svdquant_lora_b.shape != value.shape:
-                raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
-            self._svdquant_lora_b.data.copy_(
-                value.clone().detach().to(self._svdquant_lora_b.device)
-            )
-
     def disable_calib(self):
         """Disable calibration."""
         self._if_calib = False
@@ -546,12 +504,27 @@ def _get_amax(self, inputs):
         amax = amax.detach() if is_torch_export_mode() else amax.data
         return amax

-    def _validate_amax(self, amax):
-        # Dynamic control flow is not supported by torch dynamo
-        if not is_torch_export_mode() and not torch.compiler.is_compiling():
-            assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), (
-                f"Got invalid amax: {amax}"
-            )
+    def validate_attr(
+        self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
+    ):
+        """Validate attribute."""
+        attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
+        if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta):
+            return True
+        is_valid = (
+            torch.all(attr_value >= 0)
+            and not torch.any(torch.isinf(attr_value))
+            and not torch.any(torch.isnan(attr_value))
+        )
+        if is_valid:
+            return True
+        name = f"{name}." if name else ""
+        msg = f"{name}{attr_name} contains invalid values: {attr_value}"
+        if warn_error:
+            warnings.warn(msg)
+        if raise_error:
+            raise ValueError(msg)
+        return False

     def _get_bias(self, inputs):
         """Get bias from buffer or compute it dynamically."""
@@ -651,14 +624,14 @@ def _fake_quantize(self, inputs):
         amax = None
         if not self.is_mx_format:
             amax = self._get_amax(inputs)
-            self._validate_amax(amax)

         if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic":
             # Block quantization, including dynamic and static block quantization
             block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
                 inputs.dim() - 1, None
             )
-            assert block_size is not None, "block size for dynamic quantization not found."
+            if block_size is None:
+                raise ValueError("block size for dynamic quantization not found.")

             outputs = dynamic_block_quant(
                 inputs,
@@ -800,10 +773,12 @@ def set_quant_params(axis, block_reshape_size, padding, slices, amax_shape=None)
     def _process_for_blockquant(self, inputs: torch.Tensor):
         if hasattr(self, "_padding"):
             inputs = F.pad(inputs, self._padding, "constant", 0)
-        assert inputs.shape == self._original_shape, (
-            f"Input shape has changed from {self._original_shape} to {inputs.shape}."
-            " Block-quantization requires a fixed input shape."
-        )
+
+        if inputs.shape != self._original_shape:
+            raise ValueError(
+                f"Input shape has changed from {self._original_shape} to {inputs.shape}."
+                " Block-quantization requires a fixed input shape."
+            )
         inputs = inputs.reshape(self._block_reshape_size)
         return inputs

@@ -854,7 +829,7 @@ def export_amax(self) -> torch.Tensor | None:
             clamp_min, clamp_max = torch.finfo(amax.dtype).tiny, torch.finfo(amax.dtype).max
             amax = amax.clamp(min=clamp_min, max=clamp_max)

-        self._validate_amax(amax)
+        self.validate_attr(attr_name="_amax", attr_value=amax)

         if self.block_sizes is None:
             # tensorrt_llm assumes the scaling_factor dim >= 1 for per-tensor.
@@ -878,11 +853,6 @@ def forward(self, inputs):
         Returns:
             outputs: A Tensor of type output_dtype
         """
-        if hasattr(torch.onnx, "_globals"):
-            from torch.onnx._globals import GLOBALS
-        else:  # torch >= 2.9
-            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
-
         if DTensor is not None and isinstance(inputs, DTensor):
             # TensorQuantizer only handles regular non-DTensor inputs
             device_mesh, placements = inputs.device_mesh, inputs.placements
@@ -949,7 +919,6 @@ def forward(self, inputs):
                )
                assert block_size is not None, "block size for dynamic quantization not found."

-        # Collect calibration data for bias
         self.collect(inputs)

         if self._if_quant:
@@ -1012,7 +981,6 @@ def extra_repr(self):
         s += f" axis={self._axis}" if self._axis is not None else " per-tensor"
         s += f" amax={self._short_amax()}"
         s += " pre_quant_scale" if self.pre_quant_scale is not None else ""
-        s += " svdquant" if self.svdquant_lora_a is not None else ""
         s += " rotated" if self._rotate else ""
         s += (
             f" calibrator={self._calibrator.__class__.__name__}"