Skip to content

Commit 7c6b85d

Browse files
committed
Qualcomm AI Engine Direct - oss model enablement (retinanet_fpn)
Summary: - e2e script for retinanet_fpn in torchvision module - refine layout transform for graph with residual connection - support group norm operator
1 parent ca47839 commit 7c6b85d

File tree

9 files changed

+563
-4
lines changed

9 files changed

+563
-4
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from executorch.backends.qualcomm.utils.constants import (
1313
QCOM_AXIS_ORDER,
1414
QCOM_INSERTED_PERMUTE,
15+
QCOM_LAYOUT_CHANGE,
1516
QCOM_QUANT_ATTRS,
1617
QCOM_REQUANTIZE,
1718
)
@@ -34,6 +35,7 @@ class LayoutTransform(ExportPass):
3435
exir_ops.edge.aten.convolution.default,
3536
exir_ops.edge.aten.max_pool2d_with_indices.default,
3637
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
38+
exir_ops.edge.aten.native_group_norm.default,
3739
exir_ops.edge.aten.pixel_shuffle.default,
3840
exir_ops.edge.aten.pixel_unshuffle.default,
3941
exir_ops.edge.aten.upsample_bilinear2d.default,
@@ -95,6 +97,7 @@ def __init__(
9597
self.edge_program = edge_program
9698
self.insert_permute = insert_permute
9799
self.qdq_opset = {*q_ops, *dq_ops}
100+
self.transformed_tag = QCOM_AXIS_ORDER
98101

99102
def mark_as_transformed(self, node: torch.fx.Node) -> None:
100103
if isinstance(node.meta["val"], (tuple, list)):
@@ -105,18 +108,18 @@ def mark_as_transformed(self, node: torch.fx.Node) -> None:
105108
f"got {getitem_node.target.__name__}"
106109
)
107110
index = getitem_node.args[1]
108-
node.meta[QCOM_AXIS_ORDER] = self.get_axis_order(
111+
node.meta[self.transformed_tag] = self.get_axis_order(
109112
eval_shape(node.meta["val"][index].shape)
110113
)
111114
else:
112-
node.meta[QCOM_AXIS_ORDER] = self.get_axis_order(
115+
node.meta[self.transformed_tag] = self.get_axis_order(
113116
eval_shape(node.meta["val"].shape)
114117
)
115118

116119
def is_transformed_node(self, node: torch.fx.Node) -> bool:
117120
if not hasattr(node, "meta"):
118121
return False
119-
return QCOM_AXIS_ORDER in node.meta
122+
return self.transformed_tag in node.meta
120123

121124
def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
122125
return node.target in self.layout_sensitive_ops
@@ -186,8 +189,23 @@ def insert_node(self, graph_module, node, revert_layout: bool) -> None:
186189
# we need this to check the annotation boundary
187190
permute.meta[QCOM_INSERTED_PERMUTE] = True
188191

192+
# this is the case when residual connection happened:
193+
# e.g. consider following graph
194+
# x --> permute --> layer_norm --> permute --> conv2d --> add
195+
# └-------------------------------------┙
196+
# we should have premute node to be correctly inserted as:
197+
# x --> permute --> layer_norm --> permute --> qnn_permute --> conv2d --> add
198+
# └--------------------------------------> qnn_premute -┙
199+
# i.e. insert permute by condition between user and current node
200+
# if there are multiple users included
201+
is_node_transformed = self.is_transformed_node(node)
189202
for user in users:
190-
user.replace_input_with(node, permute)
203+
is_user_transformed = (
204+
self.is_transformed_node(user) or QCOM_LAYOUT_CHANGE in user.meta
205+
)
206+
# insert permute only in exclusive condition
207+
if is_node_transformed != is_user_transformed:
208+
user.replace_input_with(node, permute)
191209

192210
def create_call_function_node(
193211
self,
@@ -243,6 +261,15 @@ def call(self, graph_module: torch.fx.GraphModule):
243261
sensitive_nodes = [
244262
node for node in graph.nodes if self.is_layout_sensitive(node)
245263
]
264+
# perform first run traversal for identifying nodes subjected to layout changes
265+
if self.insert_permute:
266+
self.insert_permute, self.transformed_tag = False, QCOM_LAYOUT_CHANGE
267+
for node in sensitive_nodes:
268+
if not self.is_transformed_node(node):
269+
self.mark_as_transformed(node)
270+
self.traverse(node, graph_module)
271+
self.insert_permute, self.transformed_tag = True, QCOM_AXIS_ORDER
272+
246273
for node in sensitive_nodes:
247274
if not self.is_transformed_node(node):
248275
self.mark_as_transformed(node)

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
op_embedding,
2121
op_expand,
2222
op_gelu,
23+
op_group_norm,
2324
op_hardsigmoid,
2425
op_hardswish,
2526
op_hardtanh,
@@ -76,6 +77,7 @@
7677
op_embedding,
7778
op_expand,
7879
op_gelu,
80+
op_group_norm,
7981
op_hardswish,
8082
op_hardtanh,
8183
op_hardsigmoid,
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
10+
11+
import numpy as np
12+
import torch
13+
14+
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .qnn_constants import OpGroupNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
from .utils import get_parameter
17+
18+
19+
@register_node_visitor
20+
class GroupNormVisitor(NodeVisitor):
21+
target = ["aten.native_group_norm.default"]
22+
23+
def __init__(self, *args) -> None:
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
30+
) -> PyQnnWrapper.PyQnnOpWrapper:
31+
input_node = node.args[0]
32+
input_tensor = self.get_tensor(input_node, node)
33+
input_tensor_wrapper = self.define_tensor(
34+
input_node,
35+
input_tensor,
36+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
37+
nodes_to_wrappers,
38+
is_input_tensor=True,
39+
)
40+
41+
weight_node = node.args[1]
42+
weight_tensor = get_parameter(weight_node, self.edge_program)
43+
weight_tensor_wrapper = self.define_tensor(
44+
weight_node,
45+
weight_tensor,
46+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
47+
nodes_to_wrappers,
48+
is_input_tensor=False,
49+
)
50+
51+
bias_node = node.args[2]
52+
bias_tensor = get_parameter(bias_node, self.edge_program)
53+
bias_tensor_wrapper = self.define_tensor(
54+
bias_node,
55+
bias_tensor,
56+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
57+
nodes_to_wrappers,
58+
is_input_tensor=False,
59+
)
60+
group = node.args[6]
61+
epsilon = node.args[7]
62+
63+
output_tensor = self.get_tensor(node, node, 0)
64+
output_tensor_wrapper = self.define_tensor(
65+
node,
66+
output_tensor,
67+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
68+
nodes_to_wrappers,
69+
is_input_tensor=False,
70+
)
71+
72+
group_norm_op = PyQnnWrapper.PyQnnOpWrapper(
73+
node.name,
74+
QNN_OP_PACKAGE_NAME_QTI_AISW,
75+
OpGroupNorm.op_name,
76+
)
77+
group_norm_op.AddInputTensors(
78+
[input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
79+
)
80+
group_norm_op.AddOutputTensors([output_tensor_wrapper])
81+
group_norm_op.AddScalarParam(
82+
OpGroupNorm.param_epsilon,
83+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
84+
{"data": np.float32(epsilon)},
85+
)
86+
group_norm_op.AddScalarParam(
87+
OpGroupNorm.param_group,
88+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
89+
{"data": np.uint32(group)},
90+
)
91+
92+
return group_norm_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ class OpGelu:
147147
op_name: str = "Gelu"
148148

149149

150+
class OpGroupNorm:
151+
op_name: str = "GroupNorm"
152+
param_epsilon = "epsilon"
153+
param_group = "group"
154+
155+
150156
@dataclass(init=False, frozen=True)
151157
class OpHardSwish:
152158
op_name: str = "HardSwish"

backends/qualcomm/quantizer/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,39 @@ def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None
10101010
annotate_single_in_single_out(node, quantization_config)
10111011

10121012

1013+
@register_annotator([torch.ops.aten.group_norm.default])
1014+
def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
1015+
act_node = node.args[0]
1016+
weight_node = node.args[2]
1017+
bias_node = None
1018+
if len(node.args) > 2:
1019+
bias_node = node.args[3]
1020+
1021+
if _is_annotated([node]):
1022+
return
1023+
1024+
_annotate_input_qspec_map(
1025+
node,
1026+
act_node,
1027+
quantization_config.input_activation,
1028+
)
1029+
_annotate_input_qspec_map(
1030+
node,
1031+
weight_node,
1032+
quantization_config.weight,
1033+
)
1034+
nodes_to_mark_annotated = [node, weight_node]
1035+
if bias_node:
1036+
_annotate_input_qspec_map(
1037+
node,
1038+
bias_node,
1039+
quantization_config.bias,
1040+
)
1041+
nodes_to_mark_annotated.append(bias_node)
1042+
_annotate_output_qspec(node, quantization_config.output_activation)
1043+
_mark_nodes_as_annotated(nodes_to_mark_annotated)
1044+
1045+
10131046
@register_annotator([torch.ops.aten.flatten.using_ints])
10141047
def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
10151048
annotate_in_out_obs_sharing_op(node, quantization_config)

backends/qualcomm/tests/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,24 @@ def forward(self, x):
501501
return self.gelu(x)
502502

503503

504+
class GroupNorm(torch.nn.Module):
505+
def __init__(self, bias=True):
506+
super().__init__()
507+
self.conv = torch.nn.Conv2d(
508+
32,
509+
256,
510+
kernel_size=3,
511+
stride=1,
512+
padding=1,
513+
bias=bias,
514+
)
515+
self.norm = torch.nn.GroupNorm(32, 256)
516+
517+
def forward(self, x):
518+
y = self.conv(x)
519+
return y, self.norm(y)
520+
521+
504522
class HardSigmoid(torch.nn.Module):
505523
def __init__(self):
506524
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,13 @@ def test_qnn_backend_gelu(self):
289289
sample_input = (torch.randn(2, 5, 1, 3),)
290290
self.lower_module_and_test_output(module, sample_input)
291291

292+
def test_qnn_backend_group_norm(self):
293+
modules = [GroupNorm(), GroupNorm(bias=False)] # noqa: F405
294+
sample_input = (torch.randn(3, 32, 56, 56),)
295+
for i, module in enumerate(modules):
296+
with self.subTest(i=i):
297+
self.lower_module_and_test_output(module, sample_input)
298+
292299
def test_qnn_backend_hardsigmoid(self):
293300
module = HardSigmoid() # noqa: F405
294301
sample_input = (torch.randn(2, 5, 1, 3),)
@@ -964,6 +971,14 @@ def test_qnn_backend_gelu(self):
964971
module = self.get_qdq_module(module, sample_input)
965972
self.lower_module_and_test_output(module, sample_input)
966973

974+
def test_qnn_backend_group_norm(self):
975+
modules = [GroupNorm(), GroupNorm(bias=False)] # noqa: F405
976+
sample_input = (torch.randn(3, 32, 56, 56),)
977+
for i, module in enumerate(modules):
978+
with self.subTest(i=i):
979+
module = self.get_qdq_module(module, sample_input)
980+
self.lower_module_and_test_output(module, sample_input)
981+
967982
def test_qnn_backend_hardsigmoid(self):
968983
module = HardSigmoid() # noqa: F405
969984
sample_input = (torch.randn(2, 5, 1, 3),)
@@ -2147,6 +2162,41 @@ def test_regnet(self):
21472162
self.assertGreaterEqual(msg["top_1"], 60)
21482163
self.assertGreaterEqual(msg["top_5"], 85)
21492164

2165+
def test_retinanet(self):
2166+
if not self.required_envs([self.image_dataset]):
2167+
self.skipTest("missing required envs")
2168+
2169+
cmds = [
2170+
"python",
2171+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/retinanet.py",
2172+
"--artifact",
2173+
self.artifact_dir,
2174+
"--build_folder",
2175+
self.build_folder,
2176+
"--device",
2177+
self.device,
2178+
"--model",
2179+
self.model,
2180+
"--dataset",
2181+
self.image_dataset,
2182+
"--ip",
2183+
self.ip,
2184+
"--port",
2185+
str(self.port),
2186+
]
2187+
if self.host:
2188+
cmds.extend(["--host", self.host])
2189+
2190+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2191+
with Listener((self.ip, self.port)) as listener:
2192+
conn = listener.accept()
2193+
p.communicate()
2194+
msg = json.loads(conn.recv())
2195+
if "Error" in msg:
2196+
self.fail(msg["Error"])
2197+
else:
2198+
self.assertGreaterEqual(msg["mAP"], 0.6)
2199+
21502200
def test_squeezenet(self):
21512201
if not self.required_envs([self.image_dataset]):
21522202
self.skipTest("missing required envs")

backends/qualcomm/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
QCOM_DTYPE = "dtype"
1515
QCOM_ENCODING = "encoding"
1616
QCOM_INSERTED_PERMUTE = "qnn_permute"
17+
QCOM_LAYOUT_CHANGE = "layout_change"
1718
QCOM_OFFSET = "offset"
1819
QCOM_QUANTIZED_IO = "q_tensor_io"
1920
QCOM_QUANT_ATTRS = "quant_attrs"

0 commit comments

Comments
 (0)