diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index a73ce9acbde..851b547eb6c 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -64,6 +64,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten._softmax.default,  # TODO: Need to find a new solution to do "axis_order" to transform axis.
+        exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index d68441c2f79..1c24d00390d 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -14,6 +14,7 @@
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.slice_scatter.default,
     exir_ops.edge.aten.copy.default,
+    exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
 ]
 
 to_be_implemented_operator = [
diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 8a3ff405711..0e021c02e68 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -22,7 +22,7 @@
 from torch.fx import Node
 
 
-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
     """
diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index e07ca24d90f..abe51066ba0 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -221,6 +221,7 @@ def get_ptq_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric: bool = False,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
@@ -241,13 +242,27 @@
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
 
     # torch do not support uint16 quantization, use int32 to bypass
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, HTP can apply optimizations.
+        # If quant_min and quant_max are left as None, the observer defaults to 128 as zero_point.
+        # If uint8 quant_min/max were provided, it would use 127 as zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch removes redundant observers based on attributes such as
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing explicit quant_min and quant_max helps observers compare
+        # and further reduces the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index da7b0174c02..7a41fb1ae21 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 from enum import IntEnum, unique
+from functools import partial
 from typing import Callable, Optional, Sequence, Set
 
 import torch
@@ -67,28 +68,44 @@ class QuantDtype(IntEnum):
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int16,
+        ),
     ),
     (QuantDtype.use_16a8w, False): (
         get_16a8w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int8,
+        ),
     ),
     (QuantDtype.use_16a4w, False): (
         get_16a4w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype="int4",
+        ),
     ),
     (QuantDtype.use_8a8w, False): (
         get_8a8w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(),
+        partial(get_ptq_per_channel_quant_config),
     ),
     # QAT,
     (QuantDtype.use_16a4w, True): (
         get_16a4w_qnn_qat_config,
-        get_qat_per_channel_quant_config(torch.uint16, "int4"),
+        partial(
+            get_qat_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype="int4",
+        ),
     ),
     (QuantDtype.use_8a8w, True): (
         get_8a8w_qnn_qat_config,
-        get_qat_per_channel_quant_config(),
+        partial(get_qat_per_channel_quant_config),
     ),
 }
 
@@ -176,11 +193,18 @@ def set_quant_config(
             f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
         )
 
-        quant_config_fuc, self.per_channel_quant_config = quant_config_dict[
+        quant_config_func, per_channel_quant_config_func = quant_config_dict[
             (quant_dtype, is_qat)
         ]
         self.quant_config = (
-            quant_config_fuc(act_observer) if act_observer else quant_config_fuc()
+            quant_config_func(act_observer=act_observer)
+            if act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=act_observer)
+            if act_observer
+            else per_channel_quant_config_func()
         )
 
     def set_per_channel_conv_quant(self, enable: bool) -> None:
diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py
index 706c04fd0d9..532eb68319d 100755
--- a/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -293,10 +293,7 @@ def compile(args):
     start_quantize_ts = time.time()
     single_llama.quantize(
         quant_dtype,
-        custom_annotations=(
-            custom_annotate_llama_last_conv_16a8w,
-            annotate_matmul_16a8w,
-        ),
+        custom_annotations=(annotate_matmul_16a8w,),
    )
     end_quantize_ts = time.time()
     logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
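
Note on the act_symmetric branch in qconfig.py: the added comments claim that leaving quant_min/quant_max unset under a symmetric quint8 scheme yields zero_point 128, while an explicit uint8 range yields 127. The sketch below is illustrative only and not part of the patch; it uses MinMaxObserver as a simpler stand-in for MovingAverageMinMaxObserver, and reflects observer behavior in recent PyTorch releases (where a customized range makes the symmetric zero_point (quant_min + quant_max) // 2).

import torch
from torch.ao.quantization.observer import MinMaxObserver

# Symmetric quint8 with quant_min/quant_max left unset: the observer uses the
# full default range and reports zero_point = 128.
obs_default = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_symmetric)
obs_default(torch.randn(8, 8))
_, zp_default = obs_default.calculate_qparams()
print(zp_default)  # tensor([128])

# An explicit uint8 range marks the qrange as customized, and the symmetric
# zero_point becomes (0 + 255) // 2 = 127 -- the case the patch comments
# describe as undesired for HTP.
obs_custom = MinMaxObserver(
    dtype=torch.quint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=0,
    quant_max=255,
)
obs_custom(torch.randn(8, 8))
_, zp_custom = obs_custom.calculate_qparams()
print(zp_custom)  # tensor([127])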
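
And a note on the quantizer.py change: the (QuantDtype, is_qat) table previously stored per-channel configs evaluated at import time, so set_quant_config could not forward a caller-supplied act_observer to them. Storing functools.partial objects defers the call until the observer is known. Below is a minimal, self-contained sketch of that pattern; make_config and the table contents are hypothetical stand-ins, not the actual ExecuTorch API.

from functools import partial


def make_config(act_dtype="uint8", weight_dtype="int8", act_observer="minmax"):
    # Stand-in for get_ptq_per_channel_quant_config / get_qat_per_channel_quant_config.
    return {"act": act_dtype, "weight": weight_dtype, "observer": act_observer}


# Storing partials instead of evaluated configs keeps act_dtype/weight_dtype
# bound per entry while leaving act_observer open for late injection.
QUANT_CONFIG_TABLE = {
    ("16a4w", False): partial(make_config, act_dtype="uint16", weight_dtype="int4"),
    ("8a8w", False): partial(make_config),
}


def set_quant_config(key, act_observer=None):
    config_func = QUANT_CONFIG_TABLE[key]
    # Mirrors the patched set_quant_config: inject the observer only if given.
    return config_func(act_observer=act_observer) if act_observer else config_func()


print(set_quant_config(("16a4w", False)))
print(set_quant_config(("8a8w", False), act_observer="moving_average"))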