Commit 0e74a17
Qualcomm AI Engine Direct - Suite Operator Test Support Part 2 (#14848)
### Summary

Support the following OPs:
- Threshold OP
- Permute with negative dims
- sqrt unit test modified to use a desired input rather than random values
- rsqrt unit test modified to use a desired input rather than random values
- Per-channel conv3d support

For sqrt/rsqrt, I believe the sample input for each UT uses `rand` instead of `randn` on purpose, to prevent negative inputs. However, if we don't set `generate_random_test_inputs=False`, random values containing negative numbers are used later on, causing `nan` to show up in the output.

If everything works as expected, we should pass 6 more tests, bringing the pass rate from **90.7% -> 91.5%**.

### Test plan

UT added

cc @cccclai @shewu-quic @haowhsu-quic @DannyYuyang-quic @cbilgin
1 parent d8e07bd commit 0e74a17
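To illustrate the `rand` vs `randn` point from the summary, a minimal sketch (not part of the diff; tensor names are illustrative):

```python
import torch

# rand() samples from [0, 1), so sqrt/rsqrt stay well-defined;
# randn() samples from N(0, 1) and can produce negatives, which sqrt maps to nan.
x_safe = torch.rand(4)
x_risky = torch.randn(4)

print(torch.sqrt(x_safe))              # finite values
print(torch.sqrt(torch.tensor(-1.0)))  # tensor(nan)
```

This is why the sqrt/rsqrt UTs pin their sample inputs and set `generate_random_test_inputs=False`.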

14 files changed (+242, -47 lines)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
 from .decompose_minmaxdim import DecomposeMinMaxDim
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
+from .decompose_threshold import DecomposeThreshold
 from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -65,6 +66,7 @@
     DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeThreshold,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
```
backends/qualcomm/_passes/decompose_threshold.py (new file, inferred from the `from .decompose_threshold import DecomposeThreshold` import above)

Lines changed: 61 additions & 0 deletions

```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from executorch.exir.pass_base import ExportPass, PassResult

from .utils import merge_decomposed_graph


class DecomposeModule(torch.nn.Module):
    def __init__(self, threshold, value):
        super().__init__()
        self.threshold = threshold
        self.value = value

    def forward(self, x):
        return torch.where(x <= self.threshold, self.value, x)


class DecomposeThreshold(ExportPass):
    """
    Decompose threshold into less_equal and where.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target in {
                torch.ops.aten.threshold_.default,
                torch.ops.aten.threshold.default,
            }:
                input_node = node.args[0]
                threshold = node.args[1]
                value = node.args[2]

                model = DecomposeModule(threshold, value)
                decomposed_module = torch.export.export(
                    model, (input_node.meta["val"],), strict=True
                ).module()

                with graph.inserting_before(node):
                    # remap maps original node values to new node values,
                    # which ensures that references to nodes are correctly
                    # updated in the new graph
                    remap = {"x": input_node}
                    merge_decomposed_graph(
                        remap=remap,
                        target_node=node,
                        target_graph=graph,
                        decomposed_graph_module=decomposed_module,
                    )
                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
```
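For intuition, a minimal sketch (not from the diff) checking the identity the pass relies on, `threshold(x, t, v) == where(x <= t, v, x)`:

```python
import torch

x = torch.randn(2, 3)
t, v = 0.5, -1.0

# F.threshold keeps x where x > t and substitutes v elsewhere,
# which is exactly the where(x <= t, v, x) form emitted by the pass.
expected = torch.nn.functional.threshold(x, t, v)
decomposed = torch.where(x <= t, torch.tensor(v), x)
assert torch.equal(expected, decomposed)
```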

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -51,6 +51,7 @@ class TensorOpInfo:
     # The scalar number arg[1] is missing when using default. Result in a corner case to deal
     aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
     aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
+    aten.where.ScalarSelf: TensorOpInfo(aten.where.self, False, True),
     aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
     aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
     aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
```
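The newly covered variant corresponds to a scalar in the `self` slot of `torch.where`; a small sketch of the pattern (assuming capture dispatches it to `aten.where.ScalarSelf`, which this pass then lifts into `aten.where.self` with a constant tensor):

```python
import torch

cond = torch.randn(4) > 0
x = torch.randn(4)

# Scalar as the "true" branch: the where.ScalarSelf pattern.
y = torch.where(cond, 0.0, x)
```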

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -28,6 +28,7 @@
     DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeThreshold,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
@@ -200,6 +201,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())
+        self.add_pass(DecomposeThreshold())
         self.add_pass(DecomposeWrapWithAutocast())
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
@@ -216,6 +218,7 @@ def transform_for_export_pipeline(
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
+        self.add_pass(DecomposeThreshold())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeExpM1())
         self.add_pass(DecomposeWrapWithAutocast())
```

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -176,7 +176,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         user_0 = self.get_first_user(node)
         if "convolution" in user_0.target.__name__:
             # OIHW (pytorch) -> HWIO (QNN)
-            quant_config[QCOM_AXIS] = 3
+            quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
             quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0)
         elif "linear" in user_0.target.__name__:
             # OI (pytorch) -> OI (QNN)
@@ -218,7 +218,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
         user_0 = self.get_first_user(node)
         # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
         if "convolution" in user_0.target.__name__:
-            quant_config[QCOM_AXIS] = 3
+            quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
         else:
             quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]
```
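The hard-coded axis 3 only matched 4D conv2d weights. Since the QNN conv weight layout always ends in the output channel (HWIO for conv2d, and presumably DHWIO for conv3d), `dim() - 1` generalizes it; a quick sketch of the arithmetic:

```python
import torch

conv2d_weight = torch.empty(3, 1, 3, 3)     # OIHW -> HWIO: output channel last
conv3d_weight = torch.empty(3, 1, 3, 3, 3)  # OIDHW -> DHWIO: output channel last

# The per-channel quantization axis is always the final dim in QNN layout.
assert conv2d_weight.dim() - 1 == 3  # the old hard-coded value
assert conv3d_weight.dim() - 1 == 4  # now handled by dim() - 1
```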

backends/qualcomm/builders/op_transpose.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -42,6 +42,8 @@ def define_node(

         # permutation
         permute_order = cast(List[int], node.args[1])
+        # to prevent negative values
+        permute_order = [x % len(permute_order) for x in permute_order]
         permute_order_shape = [len(permute_order)]

         output_tensor = input_tensor.permute(permute_order)
```
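A small sketch of the normalization added above, mirroring how negative dims wrap in `torch.permute`:

```python
permute_order = [-1, 0, 1]  # negative dims are valid in PyTorch
rank = len(permute_order)

# Python's % yields a non-negative result for a positive modulus,
# so -1 % 3 == 2, matching torch's dim wrapping.
normalized = [d % rank for d in permute_order]
assert normalized == [2, 0, 1]
```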

backends/qualcomm/quantizer/annotators.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -1358,7 +1358,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
     )


-@register_annotator([torch.ops.aten.where.self])
+@register_annotator([torch.ops.aten.where.self, torch.ops.aten.where.ScalarSelf])
 def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
@@ -1368,7 +1368,6 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
     assert isinstance(input_node, Node)
     if _is_float_tensor(input_node):
         input_qspec_map[input_node] = quantization_config.input_activation
-
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
         output_qspec=(
```

backends/qualcomm/quantizer/quantizer.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -161,6 +161,7 @@ def __post_init__(self):
             {
                 torch.ops.aten.conv1d.default,
                 torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv3d.default,
                 torch.ops.aten.conv_transpose2d.input,
             }
         )
```
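With `conv3d.default` in this set, a plain Conv3d module should now qualify for per-channel weight quantization during annotation; a toy op that hits this path (shapes are illustrative, not from the diff):

```python
import torch

# torch.nn.Conv3d lowers to torch.ops.aten.conv3d.default at capture time.
m = torch.nn.Conv3d(in_channels=1, out_channels=3, kernel_size=3, padding=1)
y = m(torch.randn(1, 1, 8, 8, 8))  # output: (1, 3, 8, 8, 8)
```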

backends/qualcomm/tests/models.py

Lines changed: 47 additions & 24 deletions
```diff
@@ -598,28 +598,6 @@ def forward(self, x):
         return self.second(self.first(x))


-class Conv3dSequential(torch.nn.Module):
-    def __init__(self, bias=True):
-        super().__init__()
-        self.first = torch.nn.Conv3d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=(3, 3, 3),
-            padding=1,
-            bias=bias,
-        )
-        self.second = torch.nn.Conv3d(
-            in_channels=3,
-            out_channels=2,
-            kernel_size=(3, 3, 3),
-            padding=1,
-            bias=bias,
-        )
-
-    def forward(self, x):
-        return self.second(self.first(x))
-
-
 class Conv2dSingle(torch.nn.Module):
     def __init__(
         self,
@@ -726,6 +704,28 @@ def forward(self, x):
         return topk_values


+class Conv3dSequential(torch.nn.Module):
+    def __init__(self, bias=True):
+        super().__init__()
+        self.first = torch.nn.Conv3d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=(3, 3, 3),
+            padding=1,
+            bias=bias,
+        )
+        self.second = torch.nn.Conv3d(
+            in_channels=3,
+            out_channels=2,
+            kernel_size=(3, 3, 3),
+            padding=1,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.second(self.first(x))
+
+
 class ConvTranspose1dSingle(torch.nn.Module):
     def __init__(self, bias=True, dilation=1):
         super().__init__()
@@ -1507,6 +1507,15 @@ def forward(self, x):
     )


+class Permute(torch.nn.Module):
+    def __init__(self, dims: List[int]):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x):
+        return x.permute(self.dims)
+
+
 class PixelShuffle(torch.nn.Module):
     def __init__(self, scale):
         super().__init__()
@@ -1540,11 +1549,12 @@ def forward(self, x):


 class PowTensorScalar(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, exponent=2):
         super().__init__()
+        self.exponent = exponent

     def forward(self, x):
-        return torch.pow(x, 2)
+        return torch.pow(x, self.exponent)


 class PReLUDefault(torch.nn.Module):
@@ -2001,6 +2011,19 @@ def forward(self, x):
         return torch.tanh(x)


+class Threshold(torch.nn.Module):
+    def __init__(self, threshold=0.0, value=0.0, inplace=False):
+        super().__init__()
+        self.threshold = threshold
+        self.value = value
+        self.inplace = inplace
+
+    def forward(self, x):
+        return torch.nn.functional.threshold(
+            x, threshold=self.threshold, value=self.value, inplace=self.inplace
+        )
+
+
 class TopKandIndex(torch.nn.Module):
     def __init__(self):
         super().__init__()
```
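A hypothetical UT-style wiring of the new test modules (sample shapes are illustrative, not from the diff):

```python
import torch

# Threshold exercises the DecomposeThreshold pass; Permute with negative
# dims exercises the normalization added in op_transpose.py.
threshold_module = Threshold(threshold=0.5, value=-1.0)
permute_module = Permute(dims=[0, -1, 1, 2])

sample = torch.randn(1, 2, 3, 4)
print(threshold_module(sample).shape)  # torch.Size([1, 2, 3, 4])
print(permute_module(sample).shape)    # torch.Size([1, 4, 2, 3])
```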
