
Commit dc4be7c

winskuo-quic authored and facebook-github-bot committed
Qualcomm AI Engine Direct - Observer Fix and remove unused passes (#6225)
Summary:

- `ConvertToLinear()` is redundant in `qnn_preprocess.py`, since this pass is already called in `executorch/backends/qualcomm/utils/utils.py`.
- Some models were experiencing a significant drop in accuracy, with a few models at 0% accuracy. Adding new conditions to perform requantization and changing ptq_per_channel_quant_config's IO observer from MinMaxObserver to MovingAverageMinMaxObserver resolves the issue.

1. Why add new conditions for requantization?

We noticed this change in a PyTorch PR (pytorch/pytorch@b8eef50#diff-976c3b0c6f85048d3db01a0c394ce8eb16e2f7541f0983d0f4ef549baa4be822L152). Before that PR, two quantization specs were considered equal if they matched on `dtype` and `is_dynamic`. After it, more attributes such as `scale` and `zero_point` are compared as well, which leaves some nodes with an extra pair of QDQ nodes. As shown in the images below, there are 2 pairs of QDQ nodes after the PyTorch PR, and the 2 pairs have different scale and offset. During the QNN lowering process, a node only saves the quant info found immediately after its output. For example, the `cat` op below uses `quantize_per_tensor_default_18`'s scale and offset as the node's quant attribute, and all other quant and dequant nodes are ignored. This causes an accuracy drop; by inserting a requantize node instead, we see an accuracy improvement for most models. Taking inceptionv3 as an example, the average top-1 accuracy goes from 0% to ~75%. I have checked a couple of other models, and their accuracy either stays the same or improves. I have also provided an option for users to skip this requant optimization if they prefer not to use it.

**Before:**
![image](https://github.com/user-attachments/assets/e6048b24-347c-4a5b-8406-c11dc14d33ae)

___

**After:**
![image](https://github.com/user-attachments/assets/200cca57-f4f7-48bc-83fb-fc1595935569)

2. Why change ptq_per_channel_quant_config's IO from MinMaxObserver to MovingAverageMinMaxObserver?

After the change above, there seems to be an inference speed drop due to the added requantization. By switching to MovingAverageMinMaxObserver, I observed an inference speed improvement for some models, such as inceptionv3.

Pull Request resolved: #6225

Reviewed By: kirklandsign

Differential Revision: D64413835

Pulled By: cccclai

fbshipit-source-id: a8be66b034c69ff403f9f2985f2b584695f3798b
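To make the qspec-comparison change described in the summary concrete, here is a minimal, hypothetical sketch; the `QSpec` dataclass below is illustrative only, not PyTorch's actual quantization spec class. Under the old `dtype`/`is_dynamic` comparison the two specs look equal and one QDQ pair can be folded away, while the full-attribute comparison keeps both pairs in the graph:

```python
# Illustrative sketch of the qspec-equality change; QSpec is a stand-in,
# not PyTorch's real quantization spec class.
from dataclasses import dataclass


@dataclass
class QSpec:
    dtype: str
    is_dynamic: bool
    scale: float
    zero_point: int


a = QSpec("uint8", False, scale=0.02, zero_point=128)
b = QSpec("uint8", False, scale=0.05, zero_point=110)

# Old check: dtype/is_dynamic only -> specs considered equal, one QDQ pair folded.
old_equal = (a.dtype, a.is_dynamic) == (b.dtype, b.is_dynamic)
# New check: full attributes -> specs differ, an extra QDQ pair survives.
new_equal = (a.dtype, a.is_dynamic, a.scale, a.zero_point) == (
    b.dtype,
    b.is_dynamic,
    b.scale,
    b.zero_point,
)
print(old_equal, new_equal)  # True False
```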
1 parent 423f65d commit dc4be7c

File tree: 5 files changed (+30, -8 lines)


backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 24 additions & 4 deletions
```diff
@@ -27,9 +27,12 @@ class AnnotateQuantAttrs(ExportPass):
     generated after quantization process.
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
+    def __init__(
+        self, edge_program: torch.export.ExportedProgram, skip_advanced_requant: bool
+    ):
         super(AnnotateQuantAttrs, self).__init__()
         self.edge_program = edge_program
+        self.skip_advanced_requant = skip_advanced_requant

     def _annotate_source_nodes(
         self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
@@ -68,9 +71,26 @@ def _annotate_requant(self, n):

         # TODO: Store multiple pairs of requantize attributes when we have an op builder
         # that has multiple outputs that requires quant attributes.
-        if q_attrs["dtype"] != dq_attrs["dtype"]:
-            dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
-            n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+        if self.skip_advanced_requant:
+            if q_attrs["dtype"] != dq_attrs["dtype"]:
+                dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+        else:
+            # When dtype is the same but other specs such as scale and offset are different,
+            # insert requant to improve accuracy.
+            # Users can turn this feature off if any inference speed drop is observed.
+            if any(
+                q_attrs[attr] != dq_attrs[attr]
+                for attr in [
+                    "scale",
+                    "zero_point",
+                    "quant_min",
+                    "quant_max",
+                    "dtype",
+                ]
+            ):
+                dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs

     # Dequant all the fold_quant parameters back to fp32.
     # If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
```
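The effect of the new branch can be seen on plain attribute dicts. This is a minimal sketch with made-up values, mirroring the keys the pass compares, for the `cat`-op scenario from the summary where dtype matches but scale and offset do not:

```python
# Made-up quant/dequant attribute dicts: same dtype, different scale/zero_point.
q_attrs = {"scale": 0.02, "zero_point": 128, "quant_min": 0, "quant_max": 255, "dtype": "uint8"}
dq_attrs = {"scale": 0.05, "zero_point": 110, "quant_min": 0, "quant_max": 255, "dtype": "uint8"}

# Old condition (and the skip_advanced_requant path): dtype-only comparison,
# so no requantize is recorded and the scale/offset mismatch is dropped.
print(q_attrs["dtype"] != dq_attrs["dtype"])  # False

# New condition: any differing spec triggers requantization.
print(any(
    q_attrs[k] != dq_attrs[k]
    for k in ("scale", "zero_point", "quant_min", "quant_max", "dtype")
))  # True
```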

backends/qualcomm/qnn_preprocess.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -11,7 +11,6 @@
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

 import torch  # noqa: F401
-from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
 from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
     FuseConsecutiveTranspose,
 )
@@ -49,7 +48,6 @@ def preprocess(
         # QNN Delegate Specific Passes
         qnn_compiler_passes = PassManager(
             passes=[
-                ConvertToLinear(),
                 InsertRequantize(edge_program),
                 InsertIOQDQ(edge_program),
                 LayoutTransform(edge_program, insert_permute=True),
```

backends/qualcomm/quantizer/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -364,7 +364,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )

     weight_quantization_spec = QuantizationSpec(
```
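For intuition on why the observer swap helps, here is a small self-contained comparison on synthetic data (not the PR's calibration setup). `MovingAverageMinMaxObserver` tracks an exponential moving average of the observed min/max, so a single outlier calibration batch inflates the computed scale far less than with `MinMaxObserver`, which keeps the absolute extremes:

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver

torch.manual_seed(0)
mm = MinMaxObserver(dtype=torch.quint8)
ma = MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.quint8)

# Calibrate on typical batches, then one outlier batch.
for _ in range(100):
    x = torch.randn(1024)
    mm(x)
    ma(x)
outlier = torch.randn(1024) * 10
mm(outlier)
ma(outlier)

print(mm.calculate_qparams())  # scale dominated by the outlier range
print(ma.calculate_qparams())  # scale stays close to the typical data range
```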

backends/qualcomm/utils/constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@
 QCOM_ZERO_POINT = "zero_point"
 QCOM_ZERO_POINTS = "zero_points"
 QCOM_PASS_EXPAND_BROADCAST_SHAPE = "expand_broadcast_shape"
+QCOM_PASS_SKIP_ADVANCED_REQUANT = "skip_advanced_requant"

 # constants in backends/qualcomm/tests
 QCOM_ANNOTATION = "annotation"
```

backends/qualcomm/utils/utils.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -69,6 +69,7 @@
 )
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
+    QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
 )

@@ -305,7 +306,9 @@ def _transform(
     ConvertBmmToMatmul()(graph_module)
     ConvertInterpolateWithUpsample2D()(graph_module)
     I64toI32(edge_program)(graph_module)
-    AnnotateQuantAttrs(edge_program)(graph_module)
+    AnnotateQuantAttrs(
+        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
+    )(graph_module)
     AnnotateAndQuantScalar(edge_program)(graph_module)
     AnnotateDecomposed(edge_program)(graph_module)
     FoldQDQ()(graph_module)
```
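With `_transform` reading the flag from `custom_pass_config`, skipping the requant optimization from user code would look roughly like this. Note that the `capture_program` entry point and its `custom_pass_config` parameter are assumptions based on how this diff consumes the flag, not something the diff itself shows:

```python
# Hypothetical usage sketch: pass QCOM_PASS_SKIP_ADVANCED_REQUANT through the
# custom pass config to disable the new requantization behavior. The
# capture_program entry point and its signature are assumptions here.
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_SKIP_ADVANCED_REQUANT,
)
from executorch.backends.qualcomm.utils.utils import capture_program

edge_prog = capture_program(
    quantized_module,  # an already-quantized torch.nn.Module
    sample_inputs,     # tuple of example input tensors
    custom_pass_config=frozenset({QCOM_PASS_SKIP_ADVANCED_REQUANT}),
)
```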
