diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
index 1105ac0ef82..4e092b71892 100644
--- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
+++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
@@ -20,6 +20,14 @@
     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
@@ -47,6 +55,8 @@
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -78,6 +88,33 @@ def forward(
         return self.model.forward(tokens, self.atten_mask)
 
 
+def add_mse_weight_observer(quant_dtype, quantizer):
+    weight_dtype = (
+        torch.int4
+        if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+        else torch.int8
+    )
+    per_channel_q_config = quantizer.default_quant_config.quant_config
+    weight_qspec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=(7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max),
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+            **{"steps": 200, "use_mse": True}
+        ),
+    )
+    quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
+        input_activation=per_channel_q_config.input_activation,
+        output_activation=per_channel_q_config.output_activation,
+        weight=weight_qspec,
+        bias=_derived_bias_quant_spec,
+    )
+
+
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
@@ -142,13 +179,13 @@ def permute(w, heads):
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.bfloat16)
+    model.to(dtype=torch.float)
     model.to(device=args.device)
 
     tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
     tokens = tokens.to(device=args.device)
     atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    atten_mask = atten_mask.to(dtype=torch.float)
     inputs = (tokens, atten_mask)
 
     if args.embedding_quantize:
@@ -174,7 +211,8 @@ def permute(w, heads):
         )
         quantizer.add_custom_quant_annotations(custom_annotations)
 
-        model.has_quant_io = True
+        if args.range_setting == "mse_weight":
+            add_mse_weight_observer(quant_dtype, quantizer)
 
         with torch.no_grad():
             model = torch.export.export(model, inputs, strict=True).module()
@@ -245,6 +283,23 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, run post-training quantization. Default is 16-bit activations and 4-bit weights. Supports 8a8w, 16a4w and 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Choose which range setting method to use (e.g. mse_weight). If not specified, min-max range setting is used for weights and activations.",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="Number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.",
+        type=str,
+    )
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
@@ -257,15 +312,9 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
     eval_llama(modelname, args)
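
For context, a minimal sketch of how the pieces added above fit into the pt2e flow the script already uses (export, prepare, calibrate, convert). The wrapper function quantize_for_eval and the single-batch calibration step are illustrative assumptions for this note, not part of the patch; the real wiring lives in gen_eval_wrapper().

# Illustrative sketch only (not part of the patch): shows where the
# --range_setting mse_weight hook from this diff slots into the pt2e flow.
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_for_eval(model, inputs, quantizer, quant_dtype, range_setting=None):
    # The patch gates the MSE-based weight observer behind --range_setting mse_weight;
    # otherwise the quantizer keeps its default (min-max) observers.
    if range_setting == "mse_weight":
        add_mse_weight_observer(quant_dtype, quantizer)

    exported = torch.export.export(model, inputs, strict=True).module()
    prepared = prepare_pt2e(exported, quantizer)  # insert observers / fake-quant
    prepared(*inputs)  # calibration pass (assumed: a single example batch)
    return convert_pt2e(prepared)  # produce the quantized module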