Skip to content

Commit f4540f1

Browse files
committed
Qualcomm AI Engine Direct - Support Flip & Index_Select
1 parent 5d42a39 commit f4540f1

File tree

7 files changed

+214
-17
lines changed

7 files changed

+214
-17
lines changed

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
op_eq,
3737
op_exp,
3838
op_expand,
39+
op_flip,
3940
op_floor,
4041
op_full,
4142
op_full_like,
@@ -49,6 +50,7 @@
4950
op_hardtanh,
5051
op_index,
5152
op_index_put,
53+
op_index_select,
5254
op_instance_norm,
5355
op_layer_norm,
5456
op_le,
@@ -139,6 +141,7 @@
139141
op_eq,
140142
op_exp,
141143
op_expand,
144+
op_flip,
142145
op_floor,
143146
op_full,
144147
op_full_like,
@@ -152,6 +155,7 @@
152155
op_hardsigmoid,
153156
op_index,
154157
op_index_put,
158+
op_index_select,
155159
op_instance_norm,
156160
op_layer_norm,
157161
op_le,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
13+
from .node_visitor import NodeVisitor
14+
from .node_visitor_manager import register_node_visitor
15+
from .qnn_constants import OpStridedSlice, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
17+
18+
@register_node_visitor
19+
class Flip(NodeVisitor):
20+
target = ["aten.flip.default"]
21+
22+
def __init__(self, *args) -> None:
23+
super().__init__(*args)
24+
25+
def define_node(
26+
self,
27+
node: torch.fx.Node,
28+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
29+
) -> PyQnnWrapper.PyQnnOpWrapper:
30+
input_node = self.get_node(node.args[0])
31+
input_tensor = self.get_tensor(input_node, node)
32+
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
33+
34+
input_tensor_wrapper = self.define_tensor(
35+
input_node,
36+
node,
37+
input_tensor,
38+
tensor_type,
39+
nodes_to_wrappers,
40+
)
41+
42+
output_tensor = self.get_tensor(node, node)
43+
output_tensor_wrapper = self.define_tensor(
44+
node,
45+
node,
46+
output_tensor,
47+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
48+
nodes_to_wrappers,
49+
)
50+
51+
ranges = []
52+
53+
for dim, size in enumerate(output_tensor.shape):
54+
if dim in node.args[1]:
55+
ranges.extend([size - 1, -1, -1])
56+
else:
57+
ranges.extend([0, size, 1])
58+
59+
range_shape = [input_tensor.dim(), 3]
60+
stride_slice_op = PyQnnWrapper.PyQnnOpWrapper(
61+
node.name,
62+
QNN_OP_PACKAGE_NAME_QTI_AISW,
63+
OpStridedSlice.op_name,
64+
)
65+
stride_slice_op.AddInputTensors([input_tensor_wrapper])
66+
stride_slice_op.AddOutputTensors([output_tensor_wrapper])
67+
stride_slice_op.AddTensorParam(
68+
OpStridedSlice.param_ranges,
69+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
70+
len(range_shape),
71+
range_shape,
72+
np.array(ranges, dtype=np.int32),
73+
True,
74+
)
75+
76+
return stride_slice_op
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
16+
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
17+
18+
19+
@register_node_visitor
20+
class IndexSelect(NodeVisitor):
21+
target = ["aten.index_select.default"]
22+
23+
def __init__(self, *args) -> None:
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
30+
) -> PyQnnWrapper.PyQnnOpWrapper:
31+
input_node = self.get_node(node.args[0])
32+
input_tensor = self.get_tensor(input_node, node)
33+
input_tensor_wrapper = self.define_tensor(
34+
input_node,
35+
node,
36+
input_tensor,
37+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
38+
nodes_to_wrappers,
39+
)
40+
41+
axis = node.args[1]
42+
indices_node = node.args[2]
43+
indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
44+
assert indices_tensor.size(0) != 0, "Not support empty indices list"
45+
46+
indices_tensor_wrapper = self.define_tensor(
47+
indices_node,
48+
node,
49+
indices_tensor,
50+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
51+
nodes_to_wrappers,
52+
)
53+
54+
gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]
55+
56+
output_tensor = self.get_tensor(node, node)
57+
output_tensor_wrapper = self.define_tensor(
58+
node,
59+
node,
60+
output_tensor,
61+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
62+
nodes_to_wrappers,
63+
)
64+
gather_output_tensors = [output_tensor_wrapper]
65+
66+
gather_op = PyQnnWrapper.PyQnnOpWrapper(
67+
node.name,
68+
QNN_OP_PACKAGE_NAME_QTI_AISW,
69+
OpGather.op_name,
70+
)
71+
gather_op.AddInputTensors(gather_input_tensors)
72+
gather_op.AddOutputTensors(gather_output_tensors)
73+
74+
# If support tuple of tensor, need to refine it based on len
75+
gather_op.AddScalarParam(
76+
OpGather.param_axis,
77+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
78+
{QCOM_DATA: np.int32(axis)},
79+
)
80+
81+
return gather_op

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
exir_ops.edge.aten.adaptive_max_pool2d.default,
2020
exir_ops.edge.aten.avg_pool3d.default,
2121
exir_ops.edge.aten.div.Tensor_mode,
22-
exir_ops.edge.aten.index_select.default,
2322
exir_ops.edge.aten.log10.default,
2423
exir_ops.edge.aten.log1p.default,
2524
exir_ops.edge.aten.log2.default,
26-
exir_ops.edge.aten.flip.default,
2725
exir_ops.edge.aten.max_pool3d_with_indices.default,
2826
exir_ops.edge.aten.median.default,
2927
exir_ops.edge.aten.median.dim,

backends/qualcomm/quantizer/annotators.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,18 @@ def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
432432
def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
433433
annotate_single_in_single_out(node, quantization_config)
434434

435+
435436
@register_annotator([torch.ops.aten.index_select.default])
436437
def annotate_index_select(node: Node, quantization_config: QuantizationConfig) -> None:
437-
import pdb; pdb.set_trace()
438+
# args[2] = indices, which should be int
438439
annotate_single_in_single_out(node, quantization_config)
439440

441+
442+
@register_annotator([torch.ops.aten.flip.default])
443+
def annotate_flip(node: Node, quantization_config: QuantizationConfig) -> None:
444+
annotate_single_in_single_out(node, quantization_config)
445+
446+
440447
@register_annotator([torch.ops.aten.floor.default])
441448
def annotate_floor(node: Node, quantization_config: QuantizationConfig) -> None:
442449
annotate_single_in_single_out(node, quantization_config)

backends/qualcomm/tests/models.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -813,23 +813,29 @@ def __init__(self):
813813
def forward(self, x):
814814
return torch.special.expm1(x)
815815

816+
816817
class Flip(torch.nn.Module):
817818
def __init__(self):
818819
super().__init__()
819-
self.dims = [0,2]
820+
self.dims = [0, 2]
820821

821822
def forward(self, x):
822823
return torch.flip(x, self.dims)
823824

825+
824826
class FlipDecomp(torch.nn.Module):
825827
def __init__(self):
826828
super().__init__()
827-
self.dims = [0,2]
829+
self.dims = [0, 2]
830+
828831
def forward(self, x):
829832
for dim in self.dims:
830-
idx = torch.arange(x.size(dim) - 1, -1, -1, device=x.device)
833+
idx = torch.arange(start=x.size(dim) - 1, end=-1, step=-1)
834+
# Select using reverse index, equivalent to flipping.
831835
x = torch.index_select(x, dim, idx)
832836
return x
837+
838+
833839
class Floor(torch.nn.Module):
834840
def __init__(self):
835841
super().__init__()
@@ -1055,6 +1061,15 @@ def forward(self, input_pos, k_val):
10551061
return k_out + 0
10561062

10571063

1064+
class IndexSelect(torch.nn.Module):
1065+
def __init__(self, dim):
1066+
super().__init__()
1067+
self.dim = dim
1068+
1069+
def forward(self, x, indices):
1070+
return torch.index_select(x, self.dim, indices)
1071+
1072+
10581073
class InstanceNorm2d(torch.nn.Module):
10591074
def __init__(self, n_features, affine=True):
10601075
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,11 @@ def test_qnn_backend_expm1(self):
555555
module = ExpM1() # noqa: F405
556556
self.lower_module_and_test_output(module, sample_input)
557557

558+
def test_qnn_backend_flip(self):
559+
sample_input = (torch.randn(3, 4, 5, 6),)
560+
module = Flip() # noqa: F405
561+
self.lower_module_and_test_output(module, sample_input)
562+
558563
def test_qnn_backend_floor(self):
559564
sample_input = (torch.randn(3, 4),)
560565
module = Floor() # noqa: F405
@@ -778,6 +783,14 @@ def test_qnn_backend_index_put(self):
778783
skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer,
779784
)
780785

786+
def test_qnn_backend_index_select(self):
787+
module = IndexSelect(dim=1) # noqa: F405
788+
sample_input = (
789+
torch.randn(2, 3, 4, 5),
790+
torch.tensor([0, 2]),
791+
)
792+
self.lower_module_and_test_output(module, sample_input)
793+
781794
def test_qnn_backend_instance_norm_2d(self):
782795
modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)] # noqa: F405
783796
sample_input = (torch.randn([4, 32, 16, 16]),)
@@ -2031,17 +2044,11 @@ def test_qnn_backend_expm1(self):
20312044
self.lower_module_and_test_output(module, sample_input)
20322045

20332046
def test_qnn_backend_flip(self):
2034-
sample_input = (torch.randn(3, 4, 5,6),)
2035-
# golden_module = Flip()
2036-
decomp_module = FlipDecomp()
2037-
decomp_module = self.get_qdq_module(decomp_module, sample_input)
2038-
self.lower_module_and_test_output(decomp_module, sample_input)
2039-
# golden_out = golden_module(sample_input)
2040-
# decomp_out = decomp_module(sample_input)
2041-
# torch.testing.assert_close(golden_out, decomp_out)
2042-
2043-
2044-
2047+
sample_input = (torch.randn(3, 4, 5, 6),)
2048+
module = Flip() # noqa: F405
2049+
module = self.get_qdq_module(module, sample_input)
2050+
self.lower_module_and_test_output(module, sample_input)
2051+
20452052
def test_qnn_backend_floor(self):
20462053
sample_input = (torch.randn(3, 4),)
20472054
module = Floor() # noqa: F405
@@ -2285,6 +2292,15 @@ def test_qnn_backend_index_put(self):
22852292
skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer,
22862293
)
22872294

2295+
def test_qnn_backend_index_select(self):
2296+
module = IndexSelect(dim=1) # noqa: F405
2297+
sample_input = (
2298+
torch.randn(2, 3, 4, 5),
2299+
torch.tensor([0, 2]),
2300+
)
2301+
module = self.get_qdq_module(module, sample_input)
2302+
self.lower_module_and_test_output(module, sample_input)
2303+
22882304
def test_qnn_backend_instance_norm_2d(self):
22892305
modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)] # noqa: F405
22902306
sample_input = (torch.randn([4, 32, 16, 16]),)

0 commit comments

Comments
 (0)