Commit a288a59

Qualcomm AI Engine Direct - XR model enablement pipe_clean
Summary:
- support linalg_vector_norm, instance_norm
- expand coverage of quantization annotator
- test cases
- small refactor for _pass importing
1 parent b1d76c9 commit a288a59
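
For orientation, here is a minimal sketch (not part of this commit) of a module exercising the two newly supported ops; the module and its shapes are hypothetical, only the op names come from the summary above:

import torch

# Hypothetical XR-style block using the ops enabled by this commit.
class XrBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # InstanceNorm2d without running stats may lower to
        # aten._native_batch_norm_legit.no_stats in the edge dialect.
        self.norm = torch.nn.InstanceNorm2d(4, affine=True)

    def forward(self, x):
        x = self.norm(x)
        # linalg_vector_norm is decomposed by DecomposeLinalgVectorNorm.
        return torch.linalg.vector_norm(x, ord=2.0, dim=(2, 3), keepdim=True)

ep = torch.export.export(XrBlock(), (torch.randn(1, 4, 8, 8),))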

File tree: 16 files changed, +480 −102 lines changed
backends/qualcomm/_passes/__init__.py

Lines changed: 18 additions & 0 deletions

@@ -1,34 +1,52 @@
 from .annotate_and_quant_scalar import AnnotateAndQuantScalar
 from .annotate_decomposed import AnnotateDecomposed
 from .annotate_quant_attrs import AnnotateQuantAttrs
+from .convert_binary_op_with_scalar import ConvertBinaryOpsWithScalar
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
 from .convert_prelu import ConvertPReLU
 from .convert_to_linear import ConvertToLinear
+from .decompose_einsum import DecomposeEinsum
+from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
+from .decompose_silu import DecomposeSilu
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fold_qdq import FoldQDQ
+from .fuse_consecutive_transpose import FuseConsecutiveTranspose
 from .i64_to_i32 import I64toI32
+from .insert_io_qdq import InsertIOQDQ
+from .insert_requantize import InsertRequantize
 from .layout_transform import LayoutTransform
 from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
 from .recompose_rms_norm import RecomposeRmsNorm
+from .reduce_dynamic_range import ReduceDynamicRange
 from .remove_redundancy import RemoveRedundancy
 from .replace_index_put_input import ReplaceIndexPutInput
+from .replace_inf_buffer import ReplaceInfBuffer
 
 
 __all__ = [
     AnnotateAndQuantScalar,
     AnnotateDecomposed,
     AnnotateQuantAttrs,
     ConvertBmmToMatmul,
+    ConvertBinaryOpsWithScalar,
     ConvertInterpolateWithUpsample2D,
     ConvertPReLU,
     ConvertToLinear,
+    DecomposeEinsum,
+    DecomposeLinalgVectorNorm,
+    DecomposeSilu,
     ExpandBroadcastTensorShape,
     FoldQDQ,
+    FuseConsecutiveTranspose,
     I64toI32,
+    InsertIOQDQ,
+    InsertRequantize,
     LayoutTransform,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,
+    ReduceDynamicRange,
     RemoveRedundancy,
     ReplaceIndexPutInput,
+    ReplaceInfBuffer,
 ]

backends/qualcomm/_passes/convert_to_linear.py

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ class ConvertToLinear(ExportPass):
     mm = exir_ops.edge.aten.mm.default
 
     addmm_patterns = [
+        {view_copy: 1, permute_copy: 1, addmm: 1},
        {view_copy: 2, permute_copy: 1, addmm: 1},
        {permute_copy: 1, addmm: 1},
     ]
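
Each pattern dict maps an edge op to the number of occurrences expected in a decomposed linear cluster. As a rough sketch of the idea these dicts encode (not the pass's actual matching code; matches_pattern is a hypothetical helper):

from collections import Counter

import torch

def matches_pattern(graph: torch.fx.Graph, pattern: dict) -> bool:
    # Count call_function targets in the (sub)graph and compare them to the
    # expected op counts, e.g. {view_copy: 1, permute_copy: 1, addmm: 1}.
    counts = Counter(
        node.target for node in graph.nodes if node.op == "call_function"
    )
    return all(counts.get(op, 0) == n for op, n in pattern.items())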
backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir import to_edge
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class LinalgVectorNorm(torch.nn.Module):
+    def __init__(self, exp, dim, keepdim):
+        super().__init__()
+        self.exp = exp
+        self.dim = tuple(dim) if dim is not None else None
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        if self.dim is None:
+            x = torch.flatten(x)
+            self.dim = 0
+
+        x = torch.abs(x)
+        x = torch.pow(x, self.exp)
+        x = torch.sum(x, dim=self.dim, keepdim=self.keepdim)
+        return torch.pow(x, 1.0 / self.exp)
+
+
+class DecomposeLinalgVectorNorm(ExportPass):
+    """
+    Decompose for math equivalent op.
+    """
+
+    def __init__(self, quantization_capture=False) -> None:
+        super().__init__()
+        self.quantization_capture = quantization_capture
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if "linalg_vector_norm" in str(node.target):
+                ord = node.args[1] if len(node.args) > 1 else 2.0
+                dim = node.args[2] if len(node.args) > 2 else None
+                keepdim = node.args[3] if len(node.args) > 3 else False
+                model = LinalgVectorNorm(ord, dim, keepdim)
+                if self.quantization_capture:
+                    decomposed_module = torch.export.export(
+                        model, (node.args[0].meta["val"],)
+                    ).module()
+                else:
+                    edge_mgr = to_edge(
+                        torch.export.export(model, (node.args[0].meta["val"],))
+                    )
+                    decomposed_module = edge_mgr.exported_program()
+
+                with graph.inserting_before(node):
+                    # remap is used to map original node values to new node values,
+                    # which ensures that references to nodes are correctly updated
+                    # in the new graph
+                    remap = {"x": node.args[0]}
+
+                    for decomposed_node in decomposed_module.graph.nodes:
+                        # no need to copy the existing 'output' node
+                        if decomposed_node.op == "output":
+                            for user in node.users.copy():
+                                # remap
+                                user.replace_input_with(
+                                    node,
+                                    remap[decomposed_node.args[0][0]],
+                                )
+                        # no need to copy existing placeholders
+                        elif decomposed_node.op == "placeholder":
+                            # replace node map key from string to graph node
+                            remap[decomposed_node] = remap.pop(decomposed_node.name)
+                        else:
+                            remap[decomposed_node] = graph.node_copy(
+                                decomposed_node,
+                                arg_transform=lambda x, remap=remap: remap[x],
+                            )
+
+                    graph.erase_node(node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
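
The decomposition implements vector_norm(x, p) = (sum(|x|^p))^(1/p). A quick numerical check of that algebra (not part of the commit) against torch.linalg.vector_norm:

import torch

# Verify the abs -> pow -> sum -> pow(1/p) decomposition used above.
x = torch.randn(2, 3, 4)
p, dim, keepdim = 3.0, (1, 2), True

reference = torch.linalg.vector_norm(x, ord=p, dim=dim, keepdim=keepdim)
decomposed = torch.pow(
    torch.sum(torch.pow(torch.abs(x), p), dim=dim, keepdim=keepdim), 1.0 / p
)
assert torch.allclose(reference, decomposed, atol=1e-6)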

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions

@@ -33,8 +33,10 @@ class LayoutTransform(ExportPass):
     exir_ops.edge.aten.adaptive_avg_pool2d.default,
     exir_ops.edge.aten.avg_pool2d.default,
     exir_ops.edge.aten.convolution.default,
+    exir_ops.edge.aten.instance_norm.default,
     exir_ops.edge.aten.max_pool2d_with_indices.default,
     exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+    exir_ops.edge.aten._native_batch_norm_legit.no_stats,
     exir_ops.edge.aten.native_group_norm.default,
     exir_ops.edge.aten.pixel_shuffle.default,
     exir_ops.edge.aten.pixel_unshuffle.default,
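
Registering the two norm ops here marks them as layout-sensitive, so the pass can permute their activations into the channel-last order the backend expects (this matches the QCOM_AXIS_ORDER handling in op_batch_norm.py below, which reads the feature dim from the last axis). A purely illustrative sketch of the permutation:

import torch

# NCHW -> NHWC: the kind of axis reordering a layout-sensitive op receives.
nchw = torch.randn(1, 4, 8, 8)
nhwc = nchw.permute(0, 2, 3, 1).contiguous()  # N, H, W, C
assert nhwc.shape == (1, 8, 8, 4)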

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program():
     ConvertInterpolateWithUpsample2D,
     ConvertPReLU,
     ConvertToLinear,
+    DecomposeLinalgVectorNorm,
     ExpandBroadcastTensorShape,
     FoldQDQ,
     I64toI32,
@@ -81,6 +82,7 @@ def get_passes_dependency_for_capture_program():
     ConvertPReLU: [RemoveRedundancy],
     ConvertBmmToMatmul: [ConvertToLinear],
     ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
+    DecomposeLinalgVectorNorm: [RemoveRedundancy],
     I64toI32: [RemoveRedundancy],
     AnnotateQuantAttrs: [
         RecomposePixelUnshuffle,
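
The dict appears to map each pass to the passes that must run before it, so DecomposeLinalgVectorNorm is scheduled after RemoveRedundancy. A minimal sketch of deriving a run order from such a mapping (not the actual capture_program scheduler; order_passes is a hypothetical helper):

from graphlib import TopologicalSorter

def order_passes(dependencies: dict) -> list:
    # graphlib treats dict values as predecessors, which matches
    # "pass -> passes that must run first".
    return list(TopologicalSorter(dependencies).static_order())

# order_passes({"DecomposeLinalgVectorNorm": ["RemoveRedundancy"]})
# -> ["RemoveRedundancy", "DecomposeLinalgVectorNorm"]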

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,7 @@
     op_hardtanh,
     op_index,
     op_index_put,
+    op_instance_norm,
     op_layer_norm,
     op_le,
     op_linear,
@@ -109,6 +110,7 @@
     op_hardsigmoid,
     op_index,
     op_index_put,
+    op_instance_norm,
     op_layer_norm,
     op_le,
     op_linear,

backends/qualcomm/builders/op_batch_norm.py

Lines changed: 62 additions & 32 deletions

@@ -9,10 +9,14 @@
 
 import torch
 from executorch.backends.qualcomm.utils.constants import (
+    QCOM_AXIS_ORDER,
     QCOM_QUANT_ATTRS,
     QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
     QCOM_SCALE,
+    QCOM_ZERO_POINT,
 )
+from executorch.exir.dialects._ops import ops as exir_ops
 
 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -21,7 +25,10 @@
 
 @register_node_visitor
 class BatchNorm(NodeVisitor):
-    target = ["aten._native_batch_norm_legit_no_training.default"]
+    target = [
+        "aten._native_batch_norm_legit_no_training.default",
+        "aten._native_batch_norm_legit.no_stats",
+    ]
 
     def __init__(self, *args) -> None:
         super().__init__(*args)
@@ -43,9 +50,13 @@ def define_node(
         input_node = node.args[0]
         input_tensor = self.get_tensor(input_node, node)
 
-        mean_node, var_node, eps = node.args[3], node.args[4], 1e-9
-        mean_tensor = get_parameter(mean_node, self.edge_program)
-        var_tensor = get_parameter(var_node, self.edge_program)
+        eps = 1e-9
+        if "no_stats" in str(node.target):
+            mean_tensor = torch.Tensor([node.args[4]])
+            var_tensor = torch.Tensor([node.args[5]])
+        else:
+            mean_tensor = get_parameter(node.args[3], self.edge_program)
+            var_tensor = get_parameter(node.args[4], self.edge_program)
 
         input_tensor_wrapper = self.define_tensor(
             input_node,
@@ -54,22 +65,43 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+        batch_norm_input_tensors = [input_tensor_wrapper]
 
-        bias_node = node.args[2]
-        bias_tensor = get_parameter(bias_node, self.edge_program)
-        filter_node = node.args[1]
-        filter_tensor = get_parameter(filter_node, self.edge_program)
-
-        amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
-        bias_tensor = bias_tensor - amount
-        self.update_encoding(bias_node, bias_tensor, eps)
-        bias_tensor_wrapper = self.define_tensor(
-            bias_node,
+        output_tensor = self.get_tensor(node, node, 0)
+        output_tensor_wrapper = self.define_tensor(
             node,
-            bias_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+        batch_norm_output_tensors = [output_tensor_wrapper]
+
+        n_feature = output_tensor.shape[-1 if QCOM_AXIS_ORDER in node.meta else 1]
+        bias_node = node.args[2]
+        if bias_node is not None:
+            bias_tensor = get_parameter(bias_node, self.edge_program)
+
+        filter_node = node.args[1]
+        if filter_node is not None:
+            filter_tensor = get_parameter(filter_node, self.edge_program)
+        else:
+            # torch.fx.Node positional args: graph, name, op, target, args, kwargs
+            filter_node = torch.fx.Node(
+                node.graph,
+                node.name + "_filter",
+                "call_function",
+                exir_ops.edge.aten.scalar_tensor.default,
+                (),  # args
+                {},  # kwargs
+            )
+            filter_tensor = torch.ones(n_feature)
+            if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
+                quant_attrs = quant_attrs.copy()
+                quant_range = quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
+                quant_attrs[QCOM_ZERO_POINT] = 0
+                quant_attrs[QCOM_SCALE] = 1.0 / quant_range
+                filter_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
 
         filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps)
         self.update_encoding(filter_node, filter_tensor, eps)
@@ -80,22 +112,20 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
             nodes_to_wrappers,
         )
-
-        batch_norm_input_tensors = [
-            input_tensor_wrapper,
-            filter_tensor_wrapper,
-            bias_tensor_wrapper,
-        ]
-
-        output_tensor = self.get_tensor(node, node, 0)
-        output_tensor_wrapper = self.define_tensor(
-            node,
-            node,
-            output_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            nodes_to_wrappers,
-        )
-        batch_norm_output_tensors = [output_tensor_wrapper]
+        batch_norm_input_tensors.append(filter_tensor_wrapper)
+
+        if bias_node is not None:
+            amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
+            bias_tensor = bias_tensor - amount
+            self.update_encoding(bias_node, bias_tensor, eps)
+            bias_tensor_wrapper = self.define_tensor(
+                bias_node,
+                node,
+                bias_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+            )
+            batch_norm_input_tensors.append(bias_tensor_wrapper)
 
         batch_norm_op = PyQnnWrapper.PyQnnOpWrapper(
             node.name,
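
The visitor folds the normalization statistics into the QNN batchnorm's static inputs: filter' = gamma / sqrt(var + eps) and bias' = beta - gamma * mean / sqrt(var + eps). A quick numerical check of that algebra (not from the commit):

import torch

# y = (x - mean) / sqrt(var + eps) * gamma + beta
#   = x * filter' + bias'
x = torch.randn(2, 4, 8, 8)
gamma, beta = torch.randn(4), torch.randn(4)
mean, var, eps = torch.randn(4), torch.rand(4) + 0.5, 1e-9

reference = torch.nn.functional.batch_norm(x, mean, var, gamma, beta, eps=eps)

scale = gamma / torch.sqrt(var + eps)  # folded filter
shift = beta - mean * scale            # folded bias
folded = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
assert torch.allclose(reference, folded, atol=1e-5)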
