Skip to content

Commit aaa41d2

Browse files
haowhsu-quicDannyYuyang-quic
authored andcommitted
Qualcomm AI Engine Direct - XR model mld_a enablement (pytorch#9129)
### Summary - make index op builder more general - small refactor on layout_transform - support new pattern of upsample2d ### Test plan ```bash python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator -s $SERIAL_NO -m SM8650 -b build-android ```
1 parent 3857588 commit aaa41d2

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from executorch.exir.pass_base import ExportPass, PassResult
2020
from executorch.exir.sym_util import eval_shape
2121

22-
from .utils import dq_ops, q_ops
23-
2422

2523
class LayoutTransform(ExportPass):
2624
"""
@@ -92,8 +90,6 @@ class LayoutTransform(ExportPass):
9290
exir_ops.edge.aten.topk.default,
9391
exir_ops.edge.aten._to_copy.default,
9492
exir_ops.edge.aten.where.self,
95-
*q_ops,
96-
*dq_ops,
9793
_operator.getitem,
9894
}
9995

@@ -118,7 +114,6 @@ def __init__(
118114
super(LayoutTransform, self).__init__()
119115
self.edge_program = edge_program
120116
self.insert_permute = insert_permute
121-
self.qdq_opset = {*q_ops, *dq_ops}
122117
self.transformed_tag = QCOM_AXIS_ORDER
123118

124119
def mark_as_transformed(self, node: torch.fx.Node) -> None:

backends/qualcomm/builders/op_index.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def define_node(
3838
nodes_to_wrappers,
3939
)
4040

41-
if len(node.args[1]) > 1:
42-
# TODO consider to implement it in a recursive way.
43-
raise NotImplementedError("Not support tuple of tensor.")
44-
45-
indices_node = node.args[1][0]
41+
# e.g. x[:, index]:
42+
# > node.args[1] = [None, indices]
43+
# > axis = 1
44+
axis = len(node.args[1]) - 1
45+
indices_node = node.args[1][axis]
4646
indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
4747
assert indices_tensor.size(0) != 0, "Not support empty indices list"
4848

@@ -78,7 +78,7 @@ def define_node(
7878
gather_op.AddScalarParam(
7979
OpGather.param_axis,
8080
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
81-
{QCOM_DATA: np.int32(0)},
81+
{QCOM_DATA: np.int32(axis)},
8282
)
8383

8484
return gather_op

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,13 +754,19 @@ def forward(self, x):
754754

755755

756756
class Index(torch.nn.Module):
757-
def __init__(self):
757+
def __init__(self, axis):
758758
super().__init__()
759759
self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
760760
self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)
761+
self.axis = axis
762+
self.dispatcher = {
763+
0: lambda x: x[self.idx0] + x[self.idx1],
764+
1: lambda x: x[:, self.idx0] + x[:, self.idx1],
765+
2: lambda x: x[:, :, self.idx0] + x[:, :, self.idx1],
766+
}
761767

762768
def forward(self, x):
763-
return x[self.idx0] + x[self.idx1]
769+
return self.dispatcher[self.axis](x)
764770

765771

766772
class IndexPut(torch.nn.Module):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,11 @@ def test_qnn_backend_hardtanh(self):
474474
self.lower_module_and_test_output(module, sample_input)
475475

476476
def test_qnn_backend_index(self):
477-
module = Index() # noqa: F405
477+
modules = [Index(0), Index(1), Index(2)] # noqa: F405
478478
sample_input = (torch.randn([8, 172, 64]),)
479-
self.lower_module_and_test_output(module, sample_input)
479+
for i, module in enumerate(modules):
480+
with self.subTest(i=i):
481+
self.lower_module_and_test_output(module, sample_input)
480482

481483
def test_qnn_backend_index_put(self):
482484
module = IndexPut() # noqa: F405
@@ -1470,10 +1472,12 @@ def test_qnn_backend_hardtanh(self):
14701472
self.lower_module_and_test_output(module, sample_input)
14711473

14721474
def test_qnn_backend_index(self):
1473-
module = Index() # noqa: F405
1475+
modules = [Index(0), Index(1), Index(2)] # noqa: F405
14741476
sample_input = (torch.randn([8, 172, 64]),)
1475-
module = self.get_qdq_module(module, sample_input)
1476-
self.lower_module_and_test_output(module, sample_input)
1477+
for i, module in enumerate(modules):
1478+
with self.subTest(i=i):
1479+
module = self.get_qdq_module(module, sample_input)
1480+
self.lower_module_and_test_output(module, sample_input)
14771481

14781482
def test_qnn_backend_index_put(self):
14791483
module = IndexPut() # noqa: F405

0 commit comments

Comments
 (0)