
Commit a7ed425

Qualcomm AI Engine Direct - Fix quantization annotation for per channel quant (#7026)
Summary:
- Fix the 8a8w config in custom annotation
- Enable setting the act observer and symmetric arguments for per-channel quant
- Remove unused custom annotation in llama.py
1 parent abd739e commit a7ed425

File tree

6 files changed (+58, -20 lines)


backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten._softmax.default,  # TODO: Need to find a new solution to do "axis_order" to transform axis.
+        exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,

backends/qualcomm/partition/common_defs.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.slice_scatter.default,
     exir_ops.edge.aten.copy.default,
+    exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
 ]
 
 to_be_implemented_operator = [

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 from torch.fx import Node
 
 
-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
     """

backends/qualcomm/quantizer/qconfig.py

Lines changed: 22 additions & 7 deletions
@@ -221,6 +221,7 @@ def get_ptq_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric: bool = False,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
@@ -241,13 +242,27 @@
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
 
     # torch do not support uint16 quantization, use int32 to bypass
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch will remove redundant observers based on attributes such as:
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing values like quant_min and quant_max can help observers compare
+        # and further reduce the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
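The new act_symmetric flag only switches the activation QuantizationSpec between a symmetric and an affine scheme; weights remain per-channel in both cases. Below is a minimal sketch of requesting each mode, assuming ExecuTorch with the Qualcomm backend is installed and that the import path mirrors the file layout above (backends/qualcomm/quantizer/qconfig.py):

import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# Assumed import path, mirroring backends/qualcomm/quantizer/qconfig.py.
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_ptq_per_channel_quant_config,
)

# Symmetric activations: quant_min/quant_max are left unset so the observer
# defaults the zero_point to 128, which HTP can optimize.
symmetric_cfg = get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
    act_symmetric=True,
)

# Affine activations (the previous behavior): explicit quant_min/quant_max
# let PyTorch compare and deduplicate redundant observers.
affine_cfg = get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
    act_symmetric=False,
)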

backends/qualcomm/quantizer/quantizer.py

Lines changed: 32 additions & 8 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 from enum import IntEnum, unique
+from functools import partial
 from typing import Callable, Optional, Sequence, Set
 
 import torch
@@ -67,28 +68,44 @@ class QuantDtype(IntEnum):
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int16,
+        ),
     ),
     (QuantDtype.use_16a8w, False): (
         get_16a8w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int8,
+        ),
     ),
     (QuantDtype.use_16a4w, False): (
         get_16a4w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype="int4",
+        ),
     ),
     (QuantDtype.use_8a8w, False): (
         get_8a8w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(),
+        partial(get_ptq_per_channel_quant_config),
     ),
     # QAT,
     (QuantDtype.use_16a4w, True): (
         get_16a4w_qnn_qat_config,
-        get_qat_per_channel_quant_config(torch.uint16, "int4"),
+        partial(
+            get_qat_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype="int4",
+        ),
     ),
     (QuantDtype.use_8a8w, True): (
         get_8a8w_qnn_qat_config,
-        get_qat_per_channel_quant_config(),
+        partial(get_qat_per_channel_quant_config),
     ),
 }
 
@@ -176,11 +193,18 @@ def set_quant_config(
             f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
         )
 
-        quant_config_fuc, self.per_channel_quant_config = quant_config_dict[
+        quant_config_fuc, per_channel_quant_config_fuc = quant_config_dict[
             (quant_dtype, is_qat)
         ]
         self.quant_config = (
-            quant_config_fuc(act_observer) if act_observer else quant_config_fuc()
+            quant_config_fuc(act_observer=act_observer)
+            if act_observer
+            else quant_config_fuc()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_fuc(act_observer=act_observer)
+            if act_observer
+            else per_channel_quant_config_fuc()
         )
 
     def set_per_channel_conv_quant(self, enable: bool) -> None:
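Switching the dict entries from eagerly built configs to functools.partial means the per-channel config is now constructed inside set_quant_config, so a caller-supplied act_observer reaches both the per-tensor and the per-channel builders. A small self-contained sketch of that deferred-call pattern follows; the builder here is an illustrative stand-in, not the real ExecuTorch helper:

from functools import partial

import torch


def make_per_channel_config(
    act_dtype=torch.uint8, weight_dtype=torch.int8, act_observer=None
):
    # Illustrative stand-in for get_ptq_per_channel_quant_config: it just
    # reports which arguments it was finally called with.
    return {"act": act_dtype, "weight": weight_dtype, "observer": act_observer}


# The config dict stores an unevaluated partial with the dtypes baked in,
per_channel_config_fuc = partial(
    make_per_channel_config, act_dtype=torch.uint16, weight_dtype=torch.int16
)

# so set_quant_config can still inject an observer when it finally calls it.
act_observer = "SomeObserverClass"  # placeholder for an observer class
per_channel_quant_config = (
    per_channel_config_fuc(act_observer=act_observer)
    if act_observer
    else per_channel_config_fuc()
)
print(per_channel_quant_config)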

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 1 addition & 4 deletions
@@ -293,10 +293,7 @@ def compile(args):
     start_quantize_ts = time.time()
     single_llama.quantize(
         quant_dtype,
-        custom_annotations=(
-            custom_annotate_llama_last_conv_16a8w,
-            annotate_matmul_16a8w,
-        ),
+        custom_annotations=(annotate_matmul_16a8w,),
     )
     end_quantize_ts = time.time()
     logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
