|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | + |
| 8 | +""" |
| 9 | +This module allows range setting methods from TorchAO (formerly Quanty)
| 10 | +to be incorporated into the PT2E flow. |
| 11 | +
|
| 12 | +We implement the two main range setting methods: |
| 13 | +1) MSE weight range setting (via a custom observer) |
| 14 | +2) Activation loss weight range setting (via precomputing scales with Quanty, and loading them into a manual observer) |
| 15 | +
|
| 16 | +""" |
| 17 | +import sys |
| 18 | +import logging |
| 19 | + |
| 20 | +import torch |
| 21 | +import torch.nn as nn |
| 22 | +import torch.nn.functional as F |
| 23 | +from executorch.backends.qualcomm.quantizer.annotators import OP_ANNOTATOR |
| 24 | +from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( |
| 25 | + PerChannelParamObserver, |
| 26 | +) |
| 27 | + |
| 28 | +from executorch.backends.qualcomm.quantizer.qconfig import ( |
| 29 | + _derived_bias_quant_spec, |
| 30 | + QuantizationConfig, |
| 31 | +) |
| 32 | +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype |
| 33 | + |
| 34 | +from executorch.examples.qualcomm.utils import make_quantizer |
| 35 | + |
| 36 | +from torchao.prototype.quantization.module_swap import ( |
| 37 | + QuantizationRecipe, |
| 38 | + quantize_module_swap, |
| 39 | + QuantizedLinear, |
| 40 | +) |
| 41 | +from torchao.prototype.quantization.module_swap.module_swap import ( |
| 42 | + get_layer_parent_by_name, |
| 43 | +) |
| 44 | +from torchao.prototype.quantization.module_swap.quantized_modules import ( |
| 45 | + QuantizedEmbedding, |
| 46 | +) |
| 47 | +from torchao.prototype.quantization.module_swap.range_setting_methods import ( |
| 48 | + set_weight_range_activation_loss, |
| 49 | +) |
| 50 | + |
| 51 | +from torchao.quantization.pt2e import ( |
| 52 | + HistogramObserver, |
| 53 | + MinMaxObserver, |
| 54 | + ObserverBase, |
| 55 | + PerChannelMinMaxObserver, |
| 56 | +) |
| 57 | +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e |
| 58 | +from torchao.quantization.pt2e.quantizer import QuantizationSpec |
| 59 | + |
| 60 | + |
| 61 | +class PerChannelMSEObserver(PerChannelParamObserver): |
| 62 | + |
| 63 | + @torch.jit.export |
| 64 | + def forward(self, x_orig): |
| 65 | + # since params are static, one calibration is enough |
| 66 | + if not self.calibrated: |
| 67 | + x = x_orig.detach().to(self.min_val.dtype) |
| 68 | + self.min_val, self.max_val = self.line_search(x) |
| 69 | + self.calibrated = True |
| 70 | + |
| 71 | + return x_orig |
| 72 | + |
| 73 | + |
| 74 | + |
| 75 | +class PerChannelFixedQParamsObserver(PerChannelMinMaxObserver): |
| 76 | + r""" |
| 77 | +    Per-channel observer with a fixed scale that is set manually. Quantization is
| 78 | +    symmetric, so the zero point is always zero.
| 79 | +    If no scale has been set, calculate_qparams falls back to the inherited min/max estimate.
| 80 | + """ |
| 81 | + |
| 82 | + def __init__( |
| 83 | + self, |
| 84 | + ch_axis=0, |
| 85 | + dtype=torch.quint8, |
| 86 | + qscheme=torch.per_channel_symmetric, |
| 87 | + quant_min=0, |
| 88 | + quant_max=255, |
| 89 | + is_dynamic=False, |
| 90 | + **kwargs, |
| 91 | + ): |
| 92 | + super().__init__(ch_axis=ch_axis, dtype=dtype, qscheme=qscheme, is_dynamic=is_dynamic, **kwargs) |
| 93 | + self.quant_min = quant_min |
| 94 | + self.quant_max = quant_max |
| 95 | + |
| 96 | + def set_scale(self, scale, device): |
| 97 | + self.scale = scale.to(device=device) |
| 98 | + self.zero_point = torch.zeros_like(scale).to(device=device) |
| 99 | + |
| 100 | + @torch.jit.export |
| 101 | + def calculate_qparams(self): |
| 102 | + if hasattr(self, "scale"): |
| 103 | + return self.scale, self.zero_point |
| 104 | + return self._calculate_qparams(self.min_val, self.max_val) |
| 105 | + |
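# A minimal sketch of the fixed-qparams observer contract (illustrative only; the
# tensor shapes and values below are made up): once `set_scale` is called,
# `calculate_qparams` returns that scale with an all-zero zero point; otherwise it
# falls back to the inherited per-channel min/max estimate.
#
#     obs = PerChannelFixedQParamsObserver(
#         ch_axis=0, dtype=torch.int8, quant_min=-127, quant_max=127
#     )
#     obs(torch.randn(8, 16))                      # records min/max as the fallback
#     obs.set_scale(torch.full((8, 1), 0.01), device="cpu")
#     scale, zero_point = obs.calculate_qparams()  # returns the manually set scale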
| 106 | + |
| 107 | +def reverse_quantize_module_swap(model: nn.Module) -> nn.Module: |
| 108 | + """ |
| 109 | + Reverse `quantize_module_swap` |
| 110 | + QuantizedLinear --> Linear |
| 111 | + QuantizedEmbedding --> Embedding |
| 112 | + """ |
| 113 | + model = reverse_replace_all_linear_with_quantized(model) |
| 114 | + model = reverse_replace_all_embedding_with_quantized(model) |
| 115 | + return model |
| 116 | + |
| 117 | + |
| 118 | +def reverse_replace_all_embedding_with_quantized( |
| 119 | + model: nn.Module |
| 120 | +) -> nn.Module: |
| 121 | + """ |
| 122 | + Reverse `replace_all_embedding_with_quantized` |
| 123 | + QuantizedEmbedding --> Embedding |
| 124 | + """ |
| 125 | + for name, module in model.named_modules(): |
| 126 | + if isinstance(module, QuantizedEmbedding): |
| 127 | + embedding = nn.Embedding( |
| 128 | + num_embeddings=module.num_embeddings, |
| 129 | + embedding_dim=module.embedding_dim, |
| 130 | + padding_idx=module.padding_idx, |
| 131 | + max_norm=module.max_norm, |
| 132 | + norm_type=module.norm_type, |
| 133 | + scale_grad_by_freq=module.scale_grad_by_freq, |
| 134 | + sparse=module.sparse, |
| 135 | + _weight=module.weight, |
| 136 | + ) |
| 137 | + attribute_name = name.rsplit(".", 1)[-1] |
| 138 | + parent_of_module = get_layer_parent_by_name(model, name) |
| 139 | + setattr(parent_of_module, attribute_name, embedding) |
| 140 | + |
| 141 | + return model |
| 142 | + |
| 143 | + |
| 144 | +def reverse_replace_all_linear_with_quantized( |
| 145 | + model: nn.Module, |
| 146 | +) -> nn.Module: |
| 147 | + """ |
| 148 | + Reverse `replace_all_linear_with_quantized_linear` |
| 149 | + QuantizedLinear --> Linear |
| 150 | + """ |
| 151 | + for name, module in model.named_modules(): |
| 152 | + if isinstance(module, QuantizedLinear): |
| 153 | + linear = nn.Linear( |
| 154 | + in_features=module.in_features, |
| 155 | + out_features=module.out_features, |
| 156 | + bias=module.bias is not None, |
| 157 | + ) |
| 158 | + linear.weight = module.weight |
| 159 | + linear.bias = module.bias |
| 160 | + |
| 161 | + attribute_name = name.rsplit(".", 1)[-1] |
| 162 | + parent_of_module = get_layer_parent_by_name(model, name) |
| 163 | + setattr(parent_of_module, attribute_name, linear) |
| 164 | + |
| 165 | + return model |
| 166 | + |
| 167 | + |
| 168 | +def make_custom_quantizer(quant_dtype, range_setting_weight=None): |
| 169 | + """ |
| 170 | +    Build a quantizer that uses either the MSE observer ("mse") or the manual
| 171 | +    fixed-qparams observer ("activation_loss") for weights, depending on the range setting method provided.
| 172 | + """ |
| 173 | + quantizer = make_quantizer( |
| 174 | + quant_dtype=quant_dtype, |
| 175 | + per_channel_conv=True, |
| 176 | + per_channel_linear=True, |
| 177 | + act_observer=MinMaxObserver, |
| 178 | + ) |
| 179 | + if range_setting_weight in ("mse", "activation_loss"): |
| 180 | + if range_setting_weight == "mse": |
| 181 | +            observer = PerChannelMSEObserver.with_args(steps=200, use_mse=True)
| 182 | +        else:
| 183 | +            observer = PerChannelFixedQParamsObserver.with_args(eps=2**-12)
| 184 | + weight_dtype = ( |
| 185 | + torch.int4 |
| 186 | + if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block) |
| 187 | + else torch.int8 |
| 188 | + ) |
| 189 | + per_channel_q_config = quantizer.default_quant_config.quant_config |
| 190 | + weight_qspec = QuantizationSpec( |
| 191 | + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, |
| 192 | + quant_min=( |
| 193 | + -7 |
| 194 | + if weight_dtype == torch.int4 |
| 195 | + else torch.iinfo(weight_dtype).min + 1 |
| 196 | + ), |
| 197 | + quant_max=( |
| 198 | + 7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max |
| 199 | + ), |
| 200 | + qscheme=torch.per_channel_symmetric, |
| 201 | + ch_axis=0, |
| 202 | + observer_or_fake_quant_ctr=observer, |
| 203 | + ) |
| 204 | + quantizer.default_quant_config.per_channel_quant_config = ( |
| 205 | + QuantizationConfig( |
| 206 | + input_activation=per_channel_q_config.input_activation, |
| 207 | + output_activation=per_channel_q_config.output_activation, |
| 208 | + weight=weight_qspec, |
| 209 | + bias=_derived_bias_quant_spec, |
| 210 | + ) |
| 211 | + ) |
| 212 | + |
| 213 | + return quantizer |
| 214 | + |
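# A sketch of how this quantizer is intended to slot into the PT2E flow for the
# MSE range setting method (the model, example inputs, and export call are
# placeholders, not part of this module):
#
#     quantizer = make_custom_quantizer(QuantDtype.use_16a4w, "mse")
#     exported = torch.export.export_for_training(model, example_inputs).module()
#     prepared = prepare_pt2e(exported, quantizer)
#     prepared(*example_inputs)   # calibration; the MSE observer runs its line search on weights
#     converted = convert_pt2e(prepared)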
| 215 | + |
| 216 | +def compute_scales(model, data, num_points=100, weight_bits=4, activation_bits=16): |
| 217 | + """ |
| 218 | + Compute scales for weight quantization using activation loss range setting |
| 219 | +    Uses set_weight_range_activation_loss from Quanty:
| 220 | +    1. Perform module swap
| 221 | + 2. Apply method from Quanty to compute optimal scales |
| 222 | + 3. Save scales in dictionary |
| 223 | + 4. Undo module swap |
| 224 | + """ |
| 225 | + recipe = QuantizationRecipe( |
| 226 | + weight_bits=weight_bits, |
| 227 | + weight_quantization=True, |
| 228 | + dynamic_weights=False, |
| 229 | + weight_group_size="per_channel", |
| 230 | + activation_bits=activation_bits, |
| 231 | + activation_quantization=True, |
| 232 | + activation_group_size="per_tensor", |
| 233 | + input_quantization=True, |
| 234 | + output_quantization=True, |
| 235 | + dynamic_activations=False, |
| 236 | + ) |
| 237 | + |
| 238 | + quantized_model = quantize_module_swap(model, recipe) |
| 239 | + |
| 240 | + set_weight_range_activation_loss(quantized_model, data, 1, num_points) # batch_size = 1 |
| 241 | + scale_dict = dict() |
| 242 | + for name, module in quantized_model.named_modules(): |
| 243 | + if isinstance(module, QuantizedLinear): |
| 244 | + scale_dict[name] = module.weight_scale.clone().detach().to(device=model.device) |
| 245 | + |
| 246 | + reverse_quantize_module_swap(model) |
| 247 | + |
| 248 | + return scale_dict |
| 249 | + |
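# Example of the precompute step, as a sketch (`llama_model` and
# `calibration_data` are placeholders; the data layout is whatever
# set_weight_range_activation_loss expects):
#
#     scale_dict = compute_scales(
#         llama_model, calibration_data, num_points=100, weight_bits=4, activation_bits=16
#     )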
| 250 | + |
| 251 | +def set_scales(model, scale_dict, num_heads=32, dim=2048): |
| 252 | + """ |
| 253 | +    Given a prepared model with manual observers attached to weights, set their scales
| 254 | +    from scale_dict. This is specific to the Llama architecture as prepared in the HTP flow
| 255 | +    (for example, scales must be split per head because attention is decomposed into single-head projections).
| 256 | + """ |
| 257 | + head_dim = dim // num_heads |
| 258 | + for node in model.graph.nodes: |
| 259 | + if node.op == "get_attr": |
| 260 | +            parts = node.target.split(".")
| 261 | +            if len(parts) > 3 and parts[-3] in ("wq_sha", "wk_sha", "wv_sha"):
| 262 | +                shorter_name = parts[-3][:2]
| 263 | +                key = ".".join(["model"] + parts[:-3] + [shorter_name])
| 264 | +                observer_name = str(list(node.users.keys())[0])
| 265 | +                observer = getattr(model, observer_name)
| 266 | +                i = int(parts[-2])
| 267 | +                observer.set_scale(scale_dict[key][head_dim*i:head_dim*(i + 1), :], device=model.device)
| 268 | +            elif len(parts) > 1 and parts[-2] in ("wo_sha", "w1_conv", "w2_conv", "w3_conv"):
| 269 | +                shorter_name = parts[-2][:2]
| 270 | +                key = ".".join(["model"] + parts[:-2] + [shorter_name])
| 271 | + observer_name = str(list(node.users.keys())[0]) |
| 272 | + observer = getattr(model, observer_name) |
| 273 | + observer.set_scale(scale_dict[key], model.device) |
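
# End-to-end sketch of the activation loss path (the export call, example inputs,
# and Llama dimensions are placeholders; the MSE path is the same without
# compute_scales/set_scales):
#
#     scale_dict = compute_scales(llama_model, calibration_data)
#     quantizer = make_custom_quantizer(QuantDtype.use_16a4w, "activation_loss")
#     exported = torch.export.export_for_training(llama_model, example_inputs).module()
#     prepared = prepare_pt2e(exported, quantizer)
#     set_scales(prepared, scale_dict, num_heads=32, dim=2048)
#     prepared(*example_inputs)   # calibrate the activation observers
#     converted = convert_pt2e(prepared)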