Commit c1c5a71

Qualcomm AI Engine Direct - fix suite op (pytorch#15162)
### Summary
- Fix the annotation logic for non-arithmetic ops
- Partition out the unsupported embedding_bag op
- Use calibration inputs in suite tests when verifying quantized ops

### Test plan
e.g. `pytest --disable-warnings -c /dev/null backends/test/suite/ -k test_split_size_dimensions[qnn_16a16w]`

cc @cccclai @cbilgin
1 parent 72b1fa1 commit c1c5a71

6 files changed (+116, -27 lines)


backends/qualcomm/partition/common_defs.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -10,7 +10,11 @@
 from executorch.exir.dialects._ops import ops as exir_ops

 not_supported_operator = [
+    # output size is data-dependent on the slice index
+    exir_ops.edge.aten._embedding_bag.default,
+    # for graph sharding purposes; different from the op used in decoder models
     exir_ops.edge.dim_order_ops._clone_dim_order.default,
+    # QNN does not support 4-bit embedding
     exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
 ]
```
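
For context, a minimal sketch of how a deny-list like this is typically consulted during partitioning; the `is_node_supported` helper below is illustrative, not the actual QnnPartitioner code:

```python
# Illustrative sketch only; the real check lives in the QNN partitioner.
import torch.fx
from executorch.exir.dialects._ops import ops as exir_ops

not_supported_operator = [
    exir_ops.edge.aten._embedding_bag.default,
    exir_ops.edge.dim_order_ops._clone_dim_order.default,
    exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
]


def is_node_supported(node: torch.fx.Node) -> bool:
    # Nodes whose target is on the deny-list stay on CPU
    # instead of being lowered to the QNN delegate.
    return node.op == "call_function" and node.target not in not_supported_operator
```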

backends/qualcomm/quantizer/annotators.py

Lines changed: 34 additions & 22 deletions
```diff
@@ -25,6 +25,8 @@
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

+from .observers.concat_observer import ConcatObserver
+
 from .qconfig import (
     get_16a16w_qnn_ptq_config,
     get_16a4w_qnn_qat_config,
@@ -691,7 +693,7 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None:

 @register_annotator([torch.ops.aten.slice.Tensor])
 def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_single_in_single_out(node, quantization_config)
+    annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.slice_scatter.default])
```
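
Sharing the output qspec with the input (which the name `annotate_single_in_share_out` suggests) is sound for slice because slicing only drops elements: the output's value range is a subset of the input's, so requantizing the slice output with its own narrower qparams would just add a needless requantize node. A quick standalone sanity check of that property:

```python
import torch

x = torch.randn(4, 16)
y = x[:, 2:10]  # lowers to aten.slice.Tensor
# Every sliced value already lies within the input's observed range,
# so the slice output can safely reuse the input's (scale, zero_point).
assert y.min() >= x.min() and y.max() <= x.max()
```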
```diff
@@ -1277,31 +1279,40 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->

 @register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
 def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
-    input_nodes = node.args[0]
     if _is_annotated([node]) or not _is_float_tensor(node):
         return

-    assert isinstance(input_nodes, Sequence)
-
-    first_input_node = input_nodes[0]
-    input_qspec_map = {}
-    assert isinstance(first_input_node, Node)
-    assert isinstance(node, Node)
-    if _is_float_tensor(first_input_node):
-        input_qspec_map[first_input_node] = quantization_config.input_activation
-        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
-            (first_input_node, node)
-        )
-
-    for input_node in input_nodes[1:]:
-        if input_node not in input_qspec_map:
-            assert isinstance(input_node, Node)
-            if _is_float_tensor(input_node):
-                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
-
+    input_qspec_map, input_nodes = {}, node.args[0]
+    for input in input_nodes:
+        input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
+        if (
+            # placeholder
+            input_qspec is None
+            or
+            # keep the shared qspec here to propagate the data range
+            # without introducing extra requantizations
+            not isinstance(input_qspec.output_qspec, SharedQuantizationSpec)
+        ):
+            input_qspec_map[input] = quantization_config.input_activation
+
+    output_qspec = QuantizationSpec(
+        dtype=quantization_config.output_activation.dtype,
+        qscheme=quantization_config.output_activation.qscheme,
+        quant_max=quantization_config.output_activation.quant_max,
+        quant_min=quantization_config.output_activation.quant_min,
+        observer_or_fake_quant_ctr=ConcatObserver.with_args(
+            # we need to know the concat node in order to patch all the input observers' data ranges;
+            # since deep-copying the fake tensor (node.meta["val"]) is prohibited,
+            # we can currently only ship the graph & node name and post-process inside the observer
+            **{
+                "node_name": node.name,
+                "graph": node.graph,
+            }
+        ),
+    )
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
-        output_qspec=share_qparams_with_input_act0_qspec,
+        output_qspec=output_qspec,
         _annotated=True,
     )
```
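
Why do concat inputs need a common data range? If each input kept its own qparams, every branch would have to be requantized to the output's scale. A minimal numeric sketch (the ranges below are hypothetical):

```python
# Two concat inputs observed with different (hypothetical) ranges.
range_a = (-1.0, 1.0)
range_b = (-4.0, 4.0)

# Widening both observers to the union of the ranges, which is what
# ConcatObserver does during calibration, lets every input and the
# output share one (scale, zero_point) pair.
lo = min(range_a[0], range_b[0])
hi = max(range_a[1], range_b[1])
scale = (hi - lo) / 255  # uint8 affine quantization
zero_point = round(-lo / scale)
print(scale, zero_point)  # shared qparams; no per-input requantize needed
```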

```diff
@@ -1345,6 +1356,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
     input_act = node.args[0]
     assert isinstance(input_act, Node)
     input_qspec_map[input_act] = quantization_config.input_activation
+    share_qparams_with_input_node_qspec = SharedQuantizationSpec((input_act, node))

     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
@@ -1353,7 +1365,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:

     for user in node.users:
         user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-            output_qspec=quantization_config.output_activation,
+            output_qspec=share_qparams_with_input_node_qspec,
             _annotated=True,
         )
```

backends/qualcomm/quantizer/observers/concat_observer.py

Lines changed: 74 additions & 0 deletions
```diff
@@ -0,0 +1,74 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torchao.quantization.pt2e import UniformQuantizationObserverBase
+
+
+class ConcatObserver(UniformQuantizationObserverBase):
+    """
+    Fetch the maximum data range across all tensors to be concatenated
+    """
+
+    def __init__(
+        self,
+        node_name,
+        graph,
+        dtype=torch.uint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+        eps=torch.finfo(torch.float32).eps,  # noqa: B008
+        is_dynamic=False,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            dtype=dtype,
+            qscheme=qscheme,
+            reduce_range=reduce_range,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            factory_kwargs=factory_kwargs,
+            eps=eps,
+            is_dynamic=is_dynamic,
+            **kwargs,
+        )
+
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
+        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
+        # get the concat node and its inputs
+        self.concat_node = [node for node in graph.nodes if node.name == node_name][0]
+        self.input_nodes = self.concat_node.args[0]
+        self.input_observers = []
+
+    def forward(self, x_orig):
+        # update the running min / max first
+        self.min_val = min(self.min_val, x_orig.min())
+        self.max_val = max(self.max_val, x_orig.max())
+
+        if len(self.input_observers) == 0:
+            # collect the input observers first if they are not cached;
+            # we cannot do this in the constructor since they have not been created yet
+            for node in self.input_nodes:
+                obs_node = list(
+                    filter(lambda user: user != self.concat_node, node.users.keys())
+                )[0]
+                self.input_observers.append(
+                    getattr(obs_node.graph.owning_module, obs_node.name)
+                )
+
+        # update min / max for all observers of the input nodes
+        for observer in self.input_observers:
+            observer.min_val = self.min_val
+            observer.max_val = self.max_val
+
+        return x_orig
+
+    def calculate_qparams(self):
+        return self._calculate_qparams(self.min_val, self.max_val)
```
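
A standalone sketch of the synchronization idea above, using a stand-in observer class instead of the torchao base so it runs outside an FX graph:

```python
import torch


class ToyObserver:
    """Stand-in for an input activation observer (hypothetical)."""

    def __init__(self):
        self.min_val = torch.tensor(float("inf"))
        self.max_val = torch.tensor(float("-inf"))


# Observers attached to two concat inputs during prepare.
obs_a, obs_b = ToyObserver(), ToyObserver()

# What ConcatObserver.forward does on each calibration batch reaching the
# concat output: extend its running range, then copy that range into every
# input observer so all branches end up with identical qparams.
running_min = torch.tensor(float("inf"))
running_max = torch.tensor(float("-inf"))
for x in [torch.randn(8), 3 * torch.randn(8)]:  # fake calibration data
    running_min = torch.minimum(running_min, x.min())
    running_max = torch.maximum(running_max, x.max())
    for obs in (obs_a, obs_b):
        obs.min_val, obs.max_val = running_min, running_max

assert torch.equal(obs_a.min_val, obs_b.min_val)  # shared data range
```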

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6043,7 +6043,7 @@ def test_llama_stories_110m(self):
         # x86 does not allow weight sharing, so we don't check pte size
         if not self.enable_x86_64:
             pte_size = msg["pte_size"]
-            self.assertLessEqual(pte_size, 130_000_000)  # 130MB
+            self.assertLessEqual(pte_size, 135_000_000)  # 135MB
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai
```

backends/test/suite/operators/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ class TestCaseShim:
     def __init__(self, test_runner):
         self._test_runner = test_runner

-    def _test_op(self, model, args, flow, generate_random_test_inputs=True):
+    def _test_op(self, model, args, flow, generate_random_test_inputs=False):
         self._test_runner.lower_and_run_model(
             model, args, generate_random_test_inputs=generate_random_test_inputs
         )
```
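
Defaulting `generate_random_test_inputs` to False means the quantized op is verified with the same calibration inputs used to pick the qparams, so the test data stays inside the observed range. A small sketch of the failure mode this avoids (qparams and values below are hypothetical):

```python
import torch

# Hypothetical qparams calibrated on data in [-1, 1] (uint8 affine).
scale, zero_point = 2.0 / 255, 128


def quantize(x):
    return torch.clamp(torch.round(x / scale) + zero_point, 0, 255)


def dequantize(q):
    return (q - zero_point) * scale


calibrated = torch.tensor([0.5, -0.25])
random_new = torch.tensor([3.0, -5.0])  # outside the calibrated range

print(dequantize(quantize(calibrated)))  # close to the original values
print(dequantize(quantize(random_new)))  # saturates near ±1, large error
```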

examples/qualcomm/oss_scripts/fastvit.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -75,7 +75,7 @@ def main(args):
         dtype=torch.uint8,
         qscheme=torch.per_tensor_affine,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
-            **{"averaging_constant": 0.02}
+            **{"averaging_constant": 0.01}
         ),
     )
     weight_qspec = QuantizationSpec(
@@ -85,7 +85,7 @@ def main(args):
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
-            **{"steps": 200, "use_mse": True}
+            **{"steps": 100, "use_mse": True}
         ),
     )
     # rewrite default per-channel ptq config
@@ -114,7 +114,6 @@ def main(args):
         dataset=inputs,
         skip_node_id_set=skip_node_id_set,
         skip_node_op_set=skip_node_op_set,
-        quant_dtype=QuantDtype.use_8a8w,
         custom_quantizer=quantizer,
         shared_buffer=args.shared_buffer,
     )
```
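
On the observer tweak above: `averaging_constant` is the exponential-moving-average factor in `MovingAverageMinMaxObserver`, so halving it from 0.02 to 0.01 makes the calibrated range move more slowly and damps outlier batches. The update rule, sketched numerically with hypothetical values:

```python
# MovingAverageMinMaxObserver update: new_min = old_min + c * (batch_min - old_min)
c = 0.01  # averaging_constant
old_min, batch_min = -1.0, -3.0  # hypothetical running and batch minima
new_min = old_min + c * (batch_min - old_min)
print(new_min)  # -1.02: an outlier batch shifts the range only slightly
```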
