From 72da5c15c83a9eda5920a445e3d7d4bf70842581 Mon Sep 17 00:00:00 2001 From: Rohan Joshi Date: Tue, 22 Jul 2025 14:29:37 -0700 Subject: [PATCH] Implemented range setting in QNN llama flow (#12377)

Summary: `llama.py` now has a `--range_setting` flag, which accepts the options `mse_weight_only` and `mse_with_act_loss`. There is also an eval script for computing perplexity, `eval_llama_qnn.py` (for faster eval, try a sequence length of 1024). That script additionally has a `--quant_linear_only` flag to quantize only linear/conv nodes, for faster experiments.

Update: I've also added SpinQuant as a feature to further improve accuracy. Both `llama.py` and `eval_llama_qnn.py` also have a `--spinquant` flag, which can be used in combination with range setting or on its own. In my experiments on Llama 1B, the best results come from combining `--spinquant` with `--range_setting mse_with_act_loss`.

Reviewed By: cccclai

Differential Revision: D78127727
---
 examples/qualcomm/oss_scripts/llama/TARGETS | 12 +
 .../oss_scripts/llama/eval_llama_qnn.py | 173 ++++++----
 examples/qualcomm/oss_scripts/llama/llama.py | 98 +++++-
 .../oss_scripts/llama/model/static_llama.py | 4 +-
 .../oss_scripts/llama/range_setting_pt2e.py | 309 ++++++++++++++++++
 5 files changed, 518 insertions(+), 78 deletions(-)
 create mode 100644 examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py

diff --git a/examples/qualcomm/oss_scripts/llama/TARGETS b/examples/qualcomm/oss_scripts/llama/TARGETS index 9c5dd1ceaf9..264854d9bfc 100644 --- a/examples/qualcomm/oss_scripts/llama/TARGETS +++ b/examples/qualcomm/oss_scripts/llama/TARGETS @@ -34,6 +34,16 @@ python_library( ], ) +python_library( + name = "range_setting_pt2e", + srcs = [ + "range_setting_pt2e.py", + ], + deps = [ + "//caffe2:torch", + ], +) + python_binary( name = "llama", main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main", @@ -42,6 +52,7 @@ python_binary( ], deps = [ ":llama_lib", + "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e", ], ) @@ -55,6 +66,7 @@ python_binary( deps = [ ":llama_lib", "//executorch/examples/models/llama:eval_library", + "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e", "fbsource//third-party/pypi/lm-eval:lm-eval", ], ) diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py index 632c0051b12..b26c033eae7 100644 --- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py +++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py @@ -5,16 +5,14 @@ # LICENSE file in the root directory of this source tree. 
import argparse -import copy import json import logging import sys - -from typing import List, Tuple +import types import torch -import torch.nn as nn + from executorch.backends.qualcomm.quantizer.custom_annotation import ( annotate_linear_16a8w_in_affine_layer, annotate_matmul_16a8w, @@ -46,14 +44,19 @@ LlamaModel, ModelArgs, ) - -from executorch.examples.qualcomm.utils import make_quantizer +from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import ( + compute_scales, + make_custom_quantizer, + reverse_quantize_module_swap, + set_scales, + WrappedLlamaModel, +) from lm_eval.evaluator import simple_evaluate from pytorch_tokenizers import get_tokenizer +from torchao.prototype.spinquant import apply_spinquant -from torchao.quantization.pt2e import MinMaxObserver from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import QuantizationSpec @@ -64,30 +67,6 @@ logging.getLogger().setLevel(logging.INFO) -class WrappedLlamaModel(nn.Module): - def __init__( - self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda" - ): - super(WrappedLlamaModel, self).__init__() - self.model = model - self.max_seq_len = max_seq_len - self.use_kv_cache = use_kv_cache - self.device = device - self.atten_mask = atten_mask - - def forward( - self, - tokens: torch.Tensor, - *args, - ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: - # Pad input if necessary, since LlamaModel requires static shape - if tokens.shape[1] != self.max_seq_len: - tokens = torch.nn.functional.pad( - tokens, (0, self.max_seq_len - tokens.shape[1]) - ) - return self.model.forward(tokens, self.atten_mask) - - def add_mse_weight_observer(quant_dtype, quantizer): weight_dtype = ( torch.int4 @@ -115,24 +94,16 @@ def add_mse_weight_observer(quant_dtype, quantizer): ) -def gen_eval_wrapper(model_name, args): - tokenizer = get_tokenizer(args.tokenizer_path) +def prepare_model(model_name, args): with open(args.params) as f: - kv_config = ModelArgs(**json.load(f)) + prefill_config = ModelArgs(**json.load(f)) # TODO: support batch inputs if necessary - kv_config.max_batch_size = 1 - kv_config.max_seq_len = args.max_seq_length - kv_config.use_kv_cache = True - - prefill_config = copy.copy(kv_config) + prefill_config.max_batch_size = 1 prefill_config.max_seq_len = args.max_seq_length - prefill_config.use_kv_cache = ( - False if args.max_seq_length == args.prefill_ar_len else True - ) - config = prefill_config + prefill_config.use_kv_cache = False use_i64_token = args.embedding_quantize is not None model = LlamaModel( - config, + prefill_config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=False, @@ -173,29 +144,67 @@ def permute(w, heads): if "model" in state_dict: state_dict = state_dict["model"] + # TODO: use dtype of model checkpoint + model = model.to(device=args.device, dtype=torch.float) + inputs = model.get_example_inputs(use_kv_cache=False) + tokens, atten_mask = inputs + + scales_state_dict = {} + if args.spinquant: + config = types.SimpleNamespace( + dim=prefill_config.dim, + head_dim=prefill_config.dim // prefill_config.n_heads, + n_local_heads=prefill_config.n_heads, + intermediate_size=4 * prefill_config.dim, + ) + model.config = config + apply_spinquant( + model, + use_r1=True, + use_r2=True, + use_r4=False, + pretrained_rotation_path=None, + qkv_split=True, + ) + logging.info("Applied SpinQuant to the model") + + if args.range_setting == "mse_with_act_loss": + wrapped_model = WrappedLlamaModel( + 
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device + ) + act_bits, weight_bits = { + "8a8w": (8, 8), + "16a4w": (16, 4), + "16a4w_block": (16, 4), + }[args.ptq] + scales_state_dict = compute_scales( + wrapped_model, tokens, weight_bits, act_bits, 1600 + ) + torch.save(scales_state_dict, "scales_state_dict.pth") + logging.info("Saved scales to scales_state_dict.pth!") + reverse_quantize_module_swap(wrapped_model) + for layer in model.layers: if getattr(layer.attention, "prepare_sha", None): layer.attention.prepare_sha() if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): layer.feed_forward.prepare_feedfoward_conv() - - model.to(dtype=torch.float) - model.to(device=args.device) - - tokens, atten_mask = model.get_example_inputs(use_kv_cache=False) - tokens = tokens.to(device=args.device) - atten_mask = atten_mask.to(device=args.device) - atten_mask = atten_mask.to(dtype=torch.float) - inputs = (tokens, atten_mask) - if args.embedding_quantize: model = get_quant_embedding_transform( embedding_quantize=args.embedding_quantize )(model) model = convert_linear_to_conv2d(model) + return model, prefill_config, inputs, scales_state_dict + + +def gen_eval_wrapper(model_name, args): + tokenizer = get_tokenizer(args.tokenizer_path) + model, config, inputs, scales_state_dict = prepare_model(model_name, args) + tokens, atten_mask = inputs + use_i64_token = args.embedding_quantize is not None - if args.ptq: + if args.ptq is not None: quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") custom_annotations = (annotate_matmul_16a8w,) @@ -203,27 +212,22 @@ def permute(w, heads): custom_annotations = custom_annotations + ( annotate_linear_16a8w_in_affine_layer, ) - quantizer = make_quantizer( - quant_dtype=quant_dtype, - per_channel_conv=True, - per_channel_linear=True, - act_observer=MinMaxObserver, - ) - quantizer.add_custom_quant_annotations(custom_annotations) - if args.range_setting == "mse_weight": - add_mse_weight_observer(quant_dtype, quantizer) + quantizer = make_custom_quantizer( + quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only + ) with torch.no_grad(): + logging.info("Starting export...") model = torch.export.export(model, inputs, strict=True).module() if quant_dtype == QuantDtype.use_16a4w_block: conv_nodes = [n for n in model.graph.nodes if "conv" in n.name] block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes} quantizer.set_block_size_map(block_size_map) - + logging.info("Finished export, adding observers (prepare_pt2e)...") model = prepare_pt2e(model, quantizer) - logging.info("Quantizing the model...") + logging.info("Observers added, starting calibration...") calibrate( inputs, @@ -236,7 +240,24 @@ def permute(w, heads): use_i64_token=use_i64_token, ) + if args.range_setting == "mse_with_act_loss": + # scales_state_dict = torch.load("scales_state_dict.pth") + set_scales(model, scales_state_dict, config.head_dim) + + logging.info("Quantizing the model...") model = convert_pt2e(model) + logging.info("Quantization complete! 
Here is some sample generated text:") + + calibrate( + inputs, + "Could you tell me about Facebook?", + model, + tokenizer=tokenizer, + ar_len=args.prefill_ar_len, + max_seq_len=args.max_seq_len, + kv_updater=None, + use_i64_token=use_i64_token, + ) model = WrappedLlamaModel( model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device @@ -248,7 +269,7 @@ def permute(w, heads): max_seq_length=args.calibration_seq_length, use_kv_cache=args.use_kv_cache, generate_full_logits=args.generate_full_logits, - enable_dynamic_shape=args.enable_dynamic_shape, + enable_dynamic_shape=False, ) @@ -271,6 +292,7 @@ def eval_llama( model=eval_wrapper, tasks=args.tasks, num_fewshot=args.num_fewshot, + limit=args.fraction, ) for task, res in eval_results["results"].items(): @@ -290,9 +312,24 @@ def main() -> None: ) parser.add_argument( "--range_setting", - help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations", + help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax", type=str, ) + parser.add_argument( + "--spinquant", + help="Apply SpinQuant (R1+R2) to the model. Uses random Hadamard matrices for rotations", + action="store_true", + ) + parser.add_argument( + "--fraction", + help="the fraction of examples per task (only use this for testing)", + type=float, + ) + parser.add_argument( + "--quant_linear_only", + help="if you select this option we quantize linear layers only", + action="store_true", + ) args = parser.parse_args() args.llama_model = "llama3_2" diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index db533986119..cbd9f711bae 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -16,6 +16,7 @@ import subprocess import sys import time +import types from functools import partial from multiprocessing.connection import Client @@ -63,6 +64,14 @@ LlamaModel, ModelArgs, ) +from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import ( + compute_scales, + make_custom_quantizer, + reverse_quantize_module_swap, + set_scales, + WrappedLlamaModel, +) + from executorch.examples.qualcomm.utils import ( make_output_dir, make_quantizer, @@ -82,6 +91,8 @@ from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer +from torchao.prototype.spinquant import apply_spinquant + from torchao.quantization.pt2e import MinMaxObserver from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -380,15 +391,18 @@ def _tag_ios(self, node, fixed_point_type): return quant_io_type - def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): + def quantize( + self, + quant_dtype, + args, + tokenizer, + custom_annotations=(), + scales_state_dict=None, + ): self.quant_dtype = quant_dtype - quantizer = make_quantizer( - quant_dtype=quant_dtype, - per_channel_conv=True, - per_channel_linear=True, - act_observer=MinMaxObserver, + quantizer = make_custom_quantizer( + quant_dtype, args.range_setting, custom_annotations ) - quantizer.add_custom_quant_annotations(custom_annotations) self.has_quant_io = True fx_graph_module = None @@ -408,6 +422,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) logging.info("Quantizing the model...") + 
calibrate( self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), args.prompt[0], @@ -419,6 +434,11 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): use_i64_token=args.embedding_quantize is not None, ) + if scales_state_dict: + set_scales( + fx_graph_module, scales_state_dict, self.llama_graph_module.head_dim + ) + self.llama_graph_module = convert_pt2e(fx_graph_module) def lowering_modules( @@ -597,6 +617,55 @@ def permute(w, heads): end_load_ts = time.time() logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}") + if args.spinquant: + config = types.SimpleNamespace( + dim=prefill_config.dim, + head_dim=prefill_config.dim // prefill_config.n_heads, + n_local_heads=prefill_config.n_heads, + intermediate_size=4 * prefill_config.dim, + ) + for llama_instance in llama_instance_list: + model = llama_instance + model.config = config + # Currently this script is on CPU: run with CUDA_VISIBLE_DEVICES=-1 + apply_spinquant( + model, + use_r1=True, + use_r2=True, + use_r4=False, + pretrained_rotation_path=None, + qkv_split=True, + ) + logging.info("Applied SpinQuant to the model") + + scales_state_dict = dict() + if args.range_setting == "mse_with_act_loss": + try: + scales_state_dict = torch.load( + "scales_state_dict.pth", map_location=torch.device("cpu") + ) + logging.info("Loaded scales_state_dict from file") + except: + logging.info("Computing scales using activation loss range setting") + model = llama_instance_list[1] + model.to(torch.float) + ar_len, model.ar_len = model.ar_len, model.max_seq_len + tokens, atten_mask = model.get_example_inputs(use_kv_cache=False) + atten_mask.to(torch.float) + wrapped_model = WrappedLlamaModel( + model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device + ) + act_bits, weight_bits = { + "8a8w": (8, 8), + "16a4w": (16, 4), + "16a4w_block": (16, 4), + }[args.ptq] + scales_state_dict = compute_scales( + wrapped_model, tokens, weight_bits, act_bits, 1600 + ) + reverse_quantize_module_swap(wrapped_model) + model.ar_len = ar_len + for llama_instance in llama_instance_list: for layer in llama_instance.layers: if getattr(layer.attention, "prepare_sha", None): @@ -658,6 +727,7 @@ def permute(w, heads): args=args, tokenizer=tokenizer, custom_annotations=custom_annotations, + scales_state_dict=scales_state_dict, ) # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later if i == 0 and args.model_mode in ["hybrid", "lookahead"]: @@ -673,7 +743,7 @@ def permute(w, heads): annotate_prefill_kv_output, kv_quant_attrs=kv_quant_attrs, ), - ) + ) # temporarily remove annotate_prefill_kv_output llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ "get_quant_io_dtype_fn" @@ -1061,6 +1131,18 @@ def _build_parser(): default=8, type=int, ) + # TODO: remove mse_weight_only (doesn't help much), only keep mse_with_act_loss (=SeqMSE) + parser.add_argument( + "--range_setting", + help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax", + type=str, + ) + + parser.add_argument( + "--spinquant", + help="Apply SpinQuant (R1+R2) to the model. 
Uses random Hadamard matrices for rotations", + action="store_true", + ) parser.add_argument("-v", "--verbose", action="store_true") diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index f7893792e00..e710443f07a 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -25,8 +25,8 @@ def apply_rotary_emb_single( x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] # broadcast for batch_prefill mode input x if x.dim() == 4: - freqs_cos = freqs_cos[None, None, :, :] - freqs_sin = freqs_sin[None, None, :, :] + freqs_cos = freqs_cos[None, :, None, :] + freqs_sin = freqs_sin[None, :, None, :] x_out_r = x_r * freqs_cos - x_i * freqs_sin x_out_i = x_r * freqs_sin + x_i * freqs_cos diff --git a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py new file mode 100644 index 00000000000..4ef3e8cfe94 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +""" +The goal of this is to allow range setting methods from TorchAO (formerly Quanty) +to be incorporated into the PT2E flow. + +We implement the two main range setting methods: +1) MSE weight range setting +2) Activation loss weight range setting + +""" + +import torch +import torch.nn as nn +from executorch.backends.qualcomm.quantizer.annotators import OP_ANNOTATOR +from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( + PerChannelParamObserver, +) + +from executorch.backends.qualcomm.quantizer.qconfig import ( + _derived_bias_quant_spec, + QuantizationConfig, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.utils import make_quantizer + +from torchao.prototype.quantization.module_swap import ( + QuantizationRecipe, + quantize_module_swap, + QuantizedLinear, +) +from torchao.prototype.quantization.module_swap.module_swap import ( + get_layer_parent_by_name, +) +from torchao.prototype.quantization.module_swap.quantized_modules import ( + QuantizedEmbedding, +) +from torchao.prototype.quantization.module_swap.range_setting_methods import ( + set_weight_range_activation_loss, +) + +from torchao.quantization.pt2e import MinMaxObserver, PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec + + +class WrappedLlamaModel(nn.Module): + def __init__( + self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda" + ): + super(WrappedLlamaModel, self).__init__() + self.model = model + self.max_seq_len = max_seq_len + self.use_kv_cache = use_kv_cache + self.device = device + self.atten_mask = atten_mask + + def forward( + self, + tokens: torch.Tensor, + *args, + ): + # Pad input if necessary, since LlamaModel requires static shape + if tokens.shape[1] != self.max_seq_len: + tokens = torch.nn.functional.pad( + tokens, (0, self.max_seq_len - tokens.shape[1]) + ) + return self.model.forward(tokens, self.atten_mask) + + +class PerChannelMSEObserver(PerChannelParamObserver): + + def forward(self, x_orig): + # since params are static, one calibration is enough + if not self.calibrated: + x = 
x_orig.detach().to(self.min_val.dtype) + self.min_val, self.max_val = self.line_search(x) + self.calibrated = True + + return x_orig + + +class PerChannelFixedQParamsObserver(PerChannelMinMaxObserver): + r""" + Fixed scale that you set manually (for per channel quantization) + Symmetric quantization, so zero point is always zero + If scale not set, defaults to minmax + """ + + def __init__( + self, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_symmetric, + quant_min=0, + quant_max=255, + is_dynamic=False, + **kwargs, + ): + super().__init__( + ch_axis=ch_axis, + dtype=dtype, + qscheme=qscheme, + is_dynamic=is_dynamic, + **kwargs, + ) + self.quant_min = quant_min + self.quant_max = quant_max + + def set_scale(self, scale): + self.register_buffer("scale", scale.clone().detach()) + self.register_buffer("zero_point", torch.zeros_like(scale)) + + def calculate_qparams(self): + if hasattr(self, "scale") and hasattr(self, "zero_point"): + print("Using precomputed scale") + return self.scale, self.zero_point + print("Using minmax scale") + return self._calculate_qparams(self.min_val, self.max_val) + + +def reverse_quantize_module_swap(model: nn.Module) -> nn.Module: + model = reverse_replace_all_linear_with_quantized(model) + model = reverse_replace_all_embedding_with_quantized( + model + ) # if embedding_quantize was false, does nothing + return model + + +def reverse_replace_all_embedding_with_quantized(model: nn.Module) -> nn.Module: + for name, module in model.named_modules(): + if isinstance(module, QuantizedEmbedding): + embedding = nn.Embedding( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse, + _weight=module.weight, + ) + attribute_name = name.rsplit(".", 1)[-1] + parent_of_module = get_layer_parent_by_name(model, name) + setattr(parent_of_module, attribute_name, embedding) + + # logger.info(f"replaced {name} with original embedding") + return model + + +def reverse_replace_all_linear_with_quantized( + model: nn.Module, +) -> nn.Module: + for name, module in model.named_modules(): + if isinstance(module, QuantizedLinear): + linear = nn.Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + ) + linear.weight = module.weight + linear.bias = module.bias + + attribute_name = name.rsplit(".", 1)[-1] + parent_of_module = get_layer_parent_by_name(model, name) + setattr(parent_of_module, attribute_name, linear) + + # logger.info(f"replaced {name} with originallinear") + return model + + +def compute_scales(model, data, weight_bits, act_bits, num_points=1600): + recipe = QuantizationRecipe( + weight_bits=weight_bits, # TODO: should be based on dtype! 
+ weight_quantization=True, + dynamic_weights=False, + weight_group_size="per_channel", + activation_bits=act_bits, # same as above + activation_quantization=True, + activation_group_size="per_tensor", + input_quantization=True, + output_quantization=True, + dynamic_activations=False, + ) + + quantized_model = quantize_module_swap(model, recipe) + + set_weight_range_activation_loss( + quantized_model, data, 1, num_points + ) # batch_size = 1 for us + scales_state_dict = {} + for name, module in quantized_model.named_modules(): + if isinstance(module, QuantizedLinear): + scales_state_dict[name] = module.weight_scale.clone().detach() + + return scales_state_dict + + +def make_custom_quantizer( + quant_dtype, range_setting=None, custom_annotations=(), linear_only=False +): + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) + if range_setting in ("mse_weight_only", "mse_with_act_loss", "na"): + if range_setting == "na": + observer = PerChannelMinMaxObserver + elif range_setting == "mse_weight_only": + observer = PerChannelMSEObserver.with_args( + **{"steps": 200, "use_mse": True} + ) + else: + observer = PerChannelFixedQParamsObserver.with_args(**{"eps": 2**-12}) + weight_dtype = ( + torch.int4 + if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block) + else torch.int8 + ) + per_channel_q_config = quantizer.default_quant_config.quant_config + weight_qspec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=( + 7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max + ), + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=observer, + ) + quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig( + input_activation=per_channel_q_config.input_activation, + output_activation=per_channel_q_config.output_activation, + weight=weight_qspec, + bias=_derived_bias_quant_spec, + ) + if linear_only: + all_keys = set(OP_ANNOTATOR.keys()) + conv_keys = { + op + for op in all_keys + if op.__name__ + in ( + "conv1d.default", + "conv2d.default", + "conv_transpose2d.input", + "linear.default", + ) + } + quantizer.add_discard_ops(all_keys.difference(conv_keys)) + else: + quantizer.add_custom_quant_annotations(custom_annotations) + return quantizer + + +def set_scales(prepared_model, scales_state_dict, head_dim=64): + for node in prepared_model.graph.nodes: + if node.op == "get_attr": + split_target = node.target.split(".") + if len(split_target) > 3 and split_target[-3] in ( + "wq_sha", + "wk_sha", + "wv_sha", + ): + shorter = split_target[-3][:2] + key = ".".join(["model"] + split_target[:-3] + [shorter]) + observer_name = str(list(node.users.keys())[0]) + observer = getattr(prepared_model, observer_name) + i = int(split_target[-2]) + try: + observer.set_scale( + scales_state_dict[key][head_dim * i : head_dim * (i + 1), :] + ) + print("Set scale for", key) + except Exception: + print("Failed to set scale for ", key, node.target) + elif len(split_target) > 1 and split_target[-2] in ( + "wo_sha", + "w1_conv", + "w2_conv", + "w3_conv", + ): + shorter = split_target[-2][:2] + key = ".".join(["model"] + split_target[:-2] + [shorter]) + observer_name = str(list(node.users.keys())[0]) + observer = getattr(prepared_model, observer_name) + try: + observer.set_scale(scales_state_dict[key]) + print("Set scale for", key) + 
except Exception: + print("Failed to set scale for ", key, node.target) + elif len(split_target) > 2 and split_target[-3] == "output": + key = ".".join(["model"] + split_target[:-2]) + observer_name = str(list(node.users.keys())[0]) + observer = getattr(prepared_model, observer_name) + try: + observer.set_scale(scales_state_dict[key]) + print("Set scale for", key) + except Exception: + print("Failed to set scale for ", key, node.target)
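
Note for reviewers: below is a condensed, illustrative sketch of how the new helpers in `range_setting_pt2e.py` fit together in the `mse_with_act_loss` flow. It assumes `model`, `tokens`, `atten_mask`, `max_seq_len`, `device`, and `head_dim` are provided by the surrounding script, and it omits steps the real scripts perform (SpinQuant, the SHA/conv rewrites, embedding quantization, and the full `calibrate()` helper); see `llama.py` and `eval_llama_qnn.py` in this patch for the complete flow.

# Illustrative sketch only; not part of the diff above.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    compute_scales,
    make_custom_quantizer,
    reverse_quantize_module_swap,
    set_scales,
    WrappedLlamaModel,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# 1) Module-swap pass: learn per-channel weight scales that minimize activation loss.
wrapped = WrappedLlamaModel(
    model, atten_mask, use_kv_cache=False, max_seq_len=max_seq_len, device=device
)
scales_state_dict = compute_scales(
    wrapped, tokens, weight_bits=4, act_bits=16, num_points=1600
)
reverse_quantize_module_swap(wrapped)  # undo the module swap before export

# 2) PT2E flow: build a quantizer whose weight observers accept precomputed scales.
quantizer = make_custom_quantizer(
    QuantDtype.use_16a4w, "mse_with_act_loss", custom_annotations=()
)
exported = torch.export.export(model, (tokens, atten_mask), strict=True).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(tokens, atten_mask)  # calibration (the scripts use calibrate(...) here)
set_scales(prepared, scales_state_dict, head_dim)  # push learned scales into observers
quantized = convert_pt2e(prepared)

The learned scales are saved to scales_state_dict.pth by `eval_llama_qnn.py` and reloaded by `llama.py` when that file is present, so the expensive `compute_scales` pass only needs to run once.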