Skip to content

Commit 4b24a82

Browse files
committed
Qualcomm AI Engine Direct - xr model enablement (mld_f)
Summary - add gather op support - make cast / slice op more general
1 parent df75088 commit 4b24a82

File tree

14 files changed

+305
-30
lines changed

14 files changed

+305
-30
lines changed

backends/cadence/hifi/operators/op_bmm.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ using exec_aten::ScalarType;
1616
using executorch::runtime::KernelRuntimeContext;
1717
using executorch::runtime::kTensorDimensionLimit;
1818
using executorch::runtime::resize_tensor;
19-
using executorch::runtime::tensors_have_same_dim_order;
2019
using executorch::runtime::tensor_is_default_dim_order;
20+
using executorch::runtime::tensors_have_same_dim_order;
2121
using torch::executor::check_bmm_args;
2222
using torch::executor::Error;
2323
using torch::executor::get_bmm_out_target_size;
@@ -78,16 +78,16 @@ Tensor& bmm_out(
7878
WORD32 out_stride = p;
7979

8080
WORD32* __restrict__ tmp =
81-
(WORD32* __restrict__)kernels::allocate_temp_memory(
82-
ctx, (batch_size * m * p) * sizeof(float));
81+
(WORD32* __restrict__)kernels::allocate_temp_memory(
82+
ctx, (batch_size * m * p) * sizeof(float));
8383

8484
ET_KERNEL_CHECK(ctx, tmp != nullptr, MemoryAllocationFailed, out);
8585

8686
tmp[batch_size * m * p] = {0};
8787

8888
WORD32* __restrict__ p_o =
89-
(WORD32* __restrict__)kernels::allocate_temp_memory(
90-
ctx, (batch_size * m * p) * sizeof(WORD32));
89+
(WORD32* __restrict__)kernels::allocate_temp_memory(
90+
ctx, (batch_size * m * p) * sizeof(WORD32));
9191

9292
ET_KERNEL_CHECK(ctx, p_o != nullptr, MemoryAllocationFailed, out);
9393

backends/cadence/hifi/operators/op_mm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ Tensor& mm_out(
7676
WORD32 out_stride = p;
7777

7878
WORD32* __restrict__ p_o =
79-
(WORD32* __restrict__)kernels::allocate_temp_memory(
80-
ctx, (n * p) * sizeof(WORD32));
79+
(WORD32* __restrict__)kernels::allocate_temp_memory(
80+
ctx, (n * p) * sizeof(WORD32));
8181

8282
WORD32 p_inp_shape[2];
8383
p_inp_shape[0] = n;

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class I64toI32(ExportPass):
3131
exir_ops.edge.aten.full.default,
3232
exir_ops.edge.aten.scalar_tensor.default,
3333
}
34+
# This dict is to ensure that the input of the OPs are int64 due to Pytorch restrictions.
35+
# For example, scatter op can only accept args[2], the index, as int64.
36+
# Key: Ops to cast input to i64
37+
# Value: The args' indices to add casting op
38+
I64_IN_OPS = {
39+
exir_ops.edge.aten.gather.default: [2],
40+
exir_ops.edge.aten.scatter.src: [2],
41+
}
3442
copy_op = exir_ops.edge.aten._to_copy.default
3543

3644
def __init__(
@@ -141,11 +149,32 @@ def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule):
141149
n.replace_all_uses_with(to_dst_node)
142150
to_dst_node.args = (n,)
143151

152+
def _cast_op_args_to_i64(self, graph_module: torch.fx.GraphModule):
153+
# input will be cast to i32 during call_operator dtype propogation
154+
# insert i64 cast node to prevent operator validation failure
155+
for node in graph_module.graph.nodes:
156+
if node.target in self.I64_IN_OPS:
157+
with graph_module.graph.inserting_before(node):
158+
arg_indices = self.I64_IN_OPS[node.target]
159+
for arg_index in arg_indices:
160+
input_node = node.args[arg_index]
161+
cast_i64_node = graph_module.graph.create_node(
162+
"call_function",
163+
self.copy_op,
164+
(input_node,),
165+
{"dtype": torch.int64},
166+
)
167+
cast_i64_node.meta["val"] = node.meta["val"].to(torch.int64)
168+
args_list = list(node.args)
169+
args_list[arg_index] = cast_i64_node
170+
node.args = tuple(args_list)
171+
144172
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
145173
# Record original output dtype to ensure that if user expects int64 as output,
146174
# convert the output back to int64 if it is casted from int64->int32.
147175
self._record_original_output_dtype(graph_module)
148176
self._cast_constant_to_int32(graph_module)
177+
self._cast_op_args_to_i64(graph_module)
149178
graph_module = super().call(graph_module).graph_module
150179
self._preserve_output_dtype(graph_module)
151180
graph_module.recompile()

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def transform_for_to_edge_pipeline(
182182

183183
# Before quantizer
184184
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
185+
self.add_pass(RemoveRedundancy(quantization_capture=True))
185186
self.add_pass(ReduceDynamicRange())
186187
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
187188
self.add_pass(ReplaceArangeArgs())

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ class RemoveRedundancy(ExportPass):
1414
Trim certain operators to reduce unnecessary overhead.
1515
"""
1616

17-
def __init__(self):
17+
def __init__(self, quantization_capture=False):
1818
super(RemoveRedundancy, self).__init__()
19-
self.redundant_ops = {
19+
self.redundant_ops_general = {
2020
torch.clone: self._default_condition,
2121
torch.ops.aten.clone.default: self._default_condition,
2222
exir_ops.edge.aten.clone.default: self._default_condition,
@@ -27,7 +27,16 @@ def __init__(self):
2727
exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
2828
# remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
2929
exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
30+
torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
3031
}
32+
self.redundant_ops_annotation = {
33+
torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
34+
}
35+
self.redundant_ops = (
36+
self.redundant_ops_annotation
37+
if quantization_capture
38+
else self.redundant_ops_general
39+
)
3140

3241
def _dim_order_op_condition(self, node):
3342
dim_order = node.kwargs.get("dim_order")
@@ -49,6 +58,10 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
4958
continue
5059

5160
to_be_remove = n
61+
# assert_tensor_metadata op has no user
62+
if len(n.users.keys()) == 0:
63+
n.args = ()
64+
# normal case
5265
for user_n in list(n.users.keys()):
5366
user_n.replace_input_with(n, n.args[0])
5467
graph_module.graph.erase_node(to_be_remove)

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
op_expand,
3333
op_full,
3434
op_full_like,
35+
op_gather,
3536
op_ge,
3637
op_gelu,
3738
op_group_norm,
@@ -120,6 +121,7 @@
120121
op_expand,
121122
op_full,
122123
op_full_like,
124+
op_gather,
123125
op_ge,
124126
op_gelu,
125127
op_group_norm,
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
15+
from .node_visitor import NodeVisitor, register_node_visitor
16+
from .qnn_constants import OpCast, OpGatherElements, QNN_OP_PACKAGE_NAME_QTI_AISW
17+
18+
19+
@register_node_visitor
20+
class Gather(NodeVisitor):
21+
target = ["aten.gather.default"]
22+
23+
def __init__(self, *args) -> None:
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
30+
) -> PyQnnWrapper.PyQnnOpWrapper:
31+
input_node = node.args[0]
32+
input_tensor = self.get_tensor(input_node, node)
33+
input_tensor_wrapper = self.define_tensor(
34+
input_node,
35+
node,
36+
input_tensor,
37+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
38+
nodes_to_wrappers,
39+
)
40+
41+
dim = cast(int, node.args[1])
42+
43+
indices_node = node.args[2]
44+
indices_tensor = self.get_tensor(indices_node, node)
45+
indices_tensor_wrapper = self.define_tensor(
46+
indices_node,
47+
node,
48+
indices_tensor,
49+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
50+
nodes_to_wrappers,
51+
)
52+
53+
cast_node = self.edge_program.graph.create_node(
54+
"call_function",
55+
exir_ops.edge.aten._to_copy.default,
56+
(indices_node,),
57+
{"dtype": torch.int32},
58+
)
59+
cast_node.meta["val"] = indices_node.meta["val"].to(torch.int32)
60+
cast_tensor = self.get_tensor(cast_node, node)
61+
cast_tensor_wrapper = self.define_tensor(
62+
cast_node,
63+
node,
64+
cast_tensor,
65+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
66+
nodes_to_wrappers,
67+
)
68+
# graph is not allowed to be modified in partition stage
69+
# erase it here to prevent lowering failure
70+
self.edge_program.graph.erase_node(cast_node)
71+
cast_op = PyQnnWrapper.PyQnnOpWrapper(
72+
f"{node.name}_cast_i64_to_i32", QNN_OP_PACKAGE_NAME_QTI_AISW, OpCast.op_name
73+
)
74+
cast_op.AddInputTensors([indices_tensor_wrapper])
75+
cast_op.AddOutputTensors([cast_tensor_wrapper])
76+
77+
gather_input_tensors = [input_tensor_wrapper, cast_tensor_wrapper]
78+
output_tensor = self.get_tensor(node, node)
79+
output_tensor_wrapper = self.define_tensor(
80+
node,
81+
node,
82+
output_tensor,
83+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
84+
nodes_to_wrappers,
85+
)
86+
gather_output_tensors = [output_tensor_wrapper]
87+
88+
gather_op = PyQnnWrapper.PyQnnOpWrapper(
89+
node.name,
90+
QNN_OP_PACKAGE_NAME_QTI_AISW,
91+
OpGatherElements.op_name,
92+
)
93+
gather_op.AddInputTensors(gather_input_tensors)
94+
gather_op.AddOutputTensors(gather_output_tensors)
95+
gather_op.AddScalarParam(
96+
OpGatherElements.param_axis,
97+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
98+
{QCOM_DATA: np.uint32(dim)},
99+
)
100+
101+
return [cast_op, gather_op]

backends/qualcomm/builders/op_slice_copy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,17 @@ def define_node(
5050
dim = cast(int, node.args[1])
5151
if dim < 0:
5252
dim = dim % len(input_tensor.shape)
53-
start = cast(int, node.args[2])
53+
54+
start = 0 if node.args[2] is None else cast(int, node.args[2])
5455
if start < 0:
5556
start = start % input_tensor.shape[dim]
56-
end = min(cast(int, node.args[3]), input_tensor.shape[dim])
57-
if end < 0:
58-
end = end % input_tensor.shape[dim]
57+
58+
if len(node.args) > 3:
59+
end = min(cast(int, node.args[3]), input_tensor.shape[dim])
60+
if end < 0:
61+
end = end % input_tensor.shape[dim]
62+
else:
63+
end = input_tensor.shape[dim]
5964

6065
input_tensor_rank = len(input_tensor.shape)
6166
ranges = []

backends/qualcomm/builders/op_to.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
12+
from executorch.exir.dialects._ops import ops as exir_ops
1213

1314
from .node_visitor import NodeVisitor, register_node_visitor
1415
from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -90,9 +91,48 @@ def define_node(
9091
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
9192
nodes_to_wrappers,
9293
)
94+
node_input_tensors = [input_tensor_wrapper]
95+
96+
# if the output / input dtype is int64, we should cast it to int32 first
97+
# since int32 is the only source that can be caste into int64
98+
ops = []
99+
if (
100+
(
101+
node.meta["val"].dtype == torch.int64
102+
or input_node.meta["val"].dtype == torch.int64
103+
)
104+
# no need to add another cast node if the dtype is already integer type
105+
and input_node.meta["val"].dtype not in (torch.int32, torch.int64)
106+
):
107+
cast_node = self.edge_program.graph.create_node(
108+
"call_function",
109+
exir_ops.edge.aten._to_copy.default,
110+
(input_node,),
111+
{"dtype": torch.int32},
112+
)
113+
cast_node.meta["val"] = input_node.meta["val"].to(torch.int32)
114+
cast_tensor = self.get_tensor(cast_node, node)
115+
cast_tensor_wrapper = self.define_tensor(
116+
cast_node,
117+
node,
118+
cast_tensor,
119+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
120+
nodes_to_wrappers,
121+
)
122+
# graph is not allowed to be modified in partition stage
123+
# erase it here to prevent lowering failure
124+
self.edge_program.graph.erase_node(cast_node)
125+
cast_op = PyQnnWrapper.PyQnnOpWrapper(
126+
f"{node.name}_cast_i64_to_i32",
127+
QNN_OP_PACKAGE_NAME_QTI_AISW,
128+
OpCast.op_name,
129+
)
130+
node_input_tensors = [cast_tensor_wrapper]
131+
cast_op.AddInputTensors([input_tensor_wrapper])
132+
cast_op.AddOutputTensors([cast_tensor_wrapper])
133+
ops.append(cast_op)
93134

94135
output_tensor = self.get_tensor(node, node)
95-
96136
output_tensor_wrapper = self.define_tensor(
97137
node,
98138
node,
@@ -105,7 +145,8 @@ def define_node(
105145
op = PyQnnWrapper.PyQnnOpWrapper(
106146
node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
107147
)
108-
op.AddInputTensors([input_tensor_wrapper])
148+
op.AddInputTensors(node_input_tensors)
109149
op.AddOutputTensors([output_tensor_wrapper])
150+
ops.append(op)
110151

111-
return op
152+
return ops

backends/qualcomm/builders/qnn_constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ class OpGather:
252252
param_axis: str = "axis"
253253

254254

255+
@dataclass(init=False, frozen=True)
256+
class OpGatherElements:
257+
op_name: str = "GatherElements"
258+
param_axis: str = "axis"
259+
260+
255261
@dataclass(init=False, frozen=True)
256262
class OpGatherND:
257263
op_name: str = "GatherNd"

0 commit comments

Comments
 (0)