"""AWQ (Activation-Aware Weight Quantization) implementation for LLM quantization."""

import torch
import torch.nn as nn
from typing import Optional
from transformers import PreTrainedModel

from .quantization_engine import QuantizationConfig, QuantizedLinear


class AWQQuantizer:
    """AWQ quantization implementation."""

    def __init__(
        self,
        model: PreTrainedModel,
        bits: int = 4,
        group_size: int = 128,
        zero_point: bool = True,
        scale_dtype: str = "fp32",
        version: str = "v2",
        enable_mnn_kernel: bool = False,
    ):
        self.model = model
        self.bits = bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.scale_dtype = scale_dtype
        self.version = version
        self.enable_mnn_kernel = enable_mnn_kernel

        # Per-layer activation statistics, filled in during calibration
        self.act_scales = {}
        self.weight_scales = {}

    def quantize(
        self,
        calibration_data: Optional[torch.Tensor] = None,
        calibration_steps: int = 100
    ) -> PreTrainedModel:
        """
        Quantize the model in place using the AWQ algorithm.

        Args:
            calibration_data: Batch of inputs used to compute activation
                statistics; one sample (row) is fed per calibration step
            calibration_steps: Maximum number of calibration steps

        Returns:
            The quantized model
        """
        if calibration_data is None:
            raise ValueError("AWQ requires calibration data for quantization")

        # Prepare model for quantization
        self.model.eval()

        # Collect activation statistics
        self._collect_activation_stats(calibration_data, calibration_steps)

        # Convert linear layers to quantized versions. Snapshot the module
        # list first, since we mutate the module tree while iterating.
        for name, module in list(self.model.named_modules()):
            if isinstance(module, nn.Linear):
                # Skip layers for which no activation scale was collected
                act_scale = self.act_scales.get(name, None)
                if act_scale is None:
                    continue

                # Convert to quantized layer
                quantized = self._quantize_layer(module, act_scale)

                # Replace the layer in its parent module
                parent_name, _, child_name = name.rpartition('.')
                if parent_name:
                    parent = self.model.get_submodule(parent_name)
                    setattr(parent, child_name, quantized)
                else:
                    setattr(self.model, name, quantized)

        return self.model

    def _collect_activation_stats(
        self,
        data: torch.Tensor,
        num_steps: int
    ):
        """Collect per-input-channel activation statistics for each linear layer."""

        # Register hooks for all linear layers
        handles = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                def hook_fn(name):
                    def fn(module, input, output):
                        if name not in self.act_scales:
                            self.act_scales[name] = []
                        x = input[0].detach()
                        # Max absolute value per input channel (last dim),
                        # flattened over batch/sequence dimensions
                        scale = x.abs().reshape(-1, x.shape[-1]).amax(dim=0)
                        self.act_scales[name].append(scale)
                    return fn

                handles.append(
                    module.register_forward_hook(hook_fn(name))
                )

        # Run calibration, feeding one sample per step so the collected
        # statistics actually vary across steps
        with torch.no_grad():
            for step in range(min(num_steps, data.shape[0])):
                self.model(data[step:step + 1])

        # Remove hooks
        for handle in handles:
            handle.remove()

        # Reduce per-step statistics to a single per-channel scale, using the
        # 99.9th percentile across steps for robustness to outliers
        for name in self.act_scales:
            scales = torch.stack(self.act_scales[name])  # (steps, in_features)
            self.act_scales[name] = torch.quantile(scales, 0.999, dim=0)

    def _quantize_layer(
        self,
        layer: nn.Linear,
        act_scale: torch.Tensor
    ) -> QuantizedLinear:
        """Quantize a single linear layer using AWQ."""
        device = next(layer.parameters()).device
        act_scale = act_scale.to(device)

        # Initialize quantized layer
        quantized = QuantizedLinear(
            layer.in_features,
            layer.out_features,
            bias=layer.bias is not None,
            config=QuantizationConfig(
                bits=self.bits,
                scheme="asymmetric" if self.zero_point else "symmetric",
                granularity="per-group" if self.group_size > 0 else "per-channel",
                calibration="minmax",
                channel_wise=True,
                dtype=f"int{self.bits}",
                format="awq"
            )
        )

        # Copy the bias if present
        if layer.bias is not None:
            quantized.bias.data.copy_(layer.bias.data)

        # Weight matrix, shape (out_features, in_features)
        W = layer.weight.data.clone()

        # Scale weight columns by the per-input-channel activation scale
        W = W / act_scale.clamp(min=1e-8).view(1, -1)
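        # This step relies on the exact identity W x = (W / s) (s * x): dividing
        # weight column j by s_j is compensated by scaling activation channel j
        # up by s_j in the quantized kernel (which is why act_scale is stored
        # on the quantized layer below). At full precision the layer output is
        # unchanged; only the distribution of quantization error shifts.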

        # Compute quantization scales, one per group of output rows (or one
        # per output row when group_size <= 0)
        qmax = 2 ** (self.bits - 1) - 1
        if self.group_size > 0:
            if W.shape[0] % self.group_size != 0:
                raise ValueError(
                    f"out_features ({W.shape[0]}) must be divisible by "
                    f"group_size ({self.group_size})"
                )
            n_groups = W.shape[0] // self.group_size
            W_groups = W.view(n_groups, self.group_size, -1)

            max_abs = W_groups.abs().amax(dim=(1, 2)).clamp(min=1e-8)
            scales = qmax / max_abs

            if self.zero_point:
                group_max = W_groups.amax(dim=(1, 2))
                group_min = W_groups.amin(dim=(1, 2))
                zero_points = -(group_max + group_min) / 2 * scales
            else:
                zero_points = torch.zeros_like(scales)

            # Expand per-group values to one entry per output row so they
            # broadcast against W below
            row_scales = scales.repeat_interleave(self.group_size)
            row_zero_points = zero_points.repeat_interleave(self.group_size)
        else:
            max_abs = W.abs().amax(dim=1).clamp(min=1e-8)
            scales = qmax / max_abs
            if self.zero_point:
                max_vals = W.amax(dim=1)
                min_vals = W.amin(dim=1)
                zero_points = -(max_vals + min_vals) / 2 * scales
            else:
                zero_points = torch.zeros_like(scales)
            row_scales = scales
            row_zero_points = zero_points

        # Quantize weights and clamp to the signed integer range;
        # dequantization is W ~= (W_quant + zero_point) / scale
        W_quant = torch.round(W * row_scales.view(-1, 1) - row_zero_points.view(-1, 1))
        W_quant = W_quant.clamp_(-qmax - 1, qmax)

        # Store quantized weights and parameters
        quantized.weight_quantized.copy_(W_quant.to(torch.int8))
        quantized.weight_scale.copy_(1.0 / scales)
        quantized.weight_zero_point.copy_(zero_points)

        # Store additional AWQ-specific information
        if hasattr(quantized, 'act_scale'):
            quantized.act_scale.copy_(act_scale)

        return quantized
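

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of this module's API). It
# assumes a HuggingFace causal LM whose forward accepts input ids as the
# first positional argument; the checkpoint name and calibration text are
# placeholders, and the script must run within the package context because
# of the relative import above.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "facebook/opt-125m"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Small calibration batch; quantize() feeds one row per calibration step
    texts = ["The quick brown fox jumps over the lazy dog."] * 8
    batch = tokenizer(texts, return_tensors="pt", padding=True)

    quantizer = AWQQuantizer(model, bits=4, group_size=128)
    quantized_model = quantizer.quantize(
        calibration_data=batch["input_ids"],
        calibration_steps=8,
    )
    print(quantized_model)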