Skip to content

Commit e7560e1

Browse files
authored
Qualcomm AI Engine Direct - enable operator avg_pool3d and adaptive_avg_pool3d (#15460)
Qualcomm AI Engine Direct - enable operator avg_pool3d and adaptive_avg_pool3d

### Summary
Enable the avg_pool3d and adaptive_avg_pool3d operators.

### Test plan
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_adaptive_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_adaptive_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
1 parent 67a94af commit e7560e1

File tree

9 files changed

+426
-3
lines changed

9 files changed

+426
-3
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ class LayoutTransform(ExportPass):
4242

4343
layout_sensitive_ops = {
4444
exir_ops.edge.aten.adaptive_avg_pool2d.default,
45+
exir_ops.edge.aten._adaptive_avg_pool3d.default,
4546
exir_ops.edge.aten.avg_pool2d.default,
47+
exir_ops.edge.aten.avg_pool3d.default,
4648
exir_ops.edge.aten.convolution.default,
4749
exir_ops.edge.aten.instance_norm.default,
4850
exir_ops.edge.aten.max_pool2d_with_indices.default,

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ Please help update following table if you are contributing new operators:
448448
| Pack | ✓ |
449449
| Pad | ✓ |
450450
| PoolAvg2d | ✓ |
451-
| PoolAvg3d | ✗ |
451+
| PoolAvg3d | ✓ |
452452
| PoolMax2d | ✓ |
453453
| Prelu | ✓ |
454454
| Quantize | ✓ |

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
op_asin,
1919
op_atan,
2020
op_avg_pool2d,
21+
op_avg_pool3d,
2122
op_batch_norm,
2223
op_binary,
2324
op_bmm,
@@ -123,6 +124,7 @@
123124
op_asin,
124125
op_atan,
125126
op_avg_pool2d,
127+
op_avg_pool3d,
126128
op_batch_norm,
127129
op_binary,
128130
op_bmm,
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import warnings
7+
from typing import cast, Dict, List
8+
9+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
10+
import numpy as np
11+
12+
import torch
13+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
14+
15+
from .node_visitor import NodeVisitor
16+
from .node_visitor_manager import register_node_visitor
17+
from .qnn_constants import OpPoolAvg3d, QNN_OP_PACKAGE_NAME_QTI_AISW
18+
19+
20+
@register_node_visitor
class AvgPool3d(NodeVisitor):
    """Builder that lowers ``aten.avg_pool3d.default`` to a QNN ``PoolAvg3d`` op.

    The node arguments are read positionally following the aten schema
    ``(input, kernel_size, stride=[], padding=0, ceil_mode=False,
    count_include_pad=True, divisor_override=None)``; trailing arguments may
    be absent from ``node.args`` when they carry their default value.
    """

    target = ["aten.avg_pool3d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _as_triple(values) -> List[int]:
        """Return a fresh list for ``values``, broadcasting a single entry to 3.

        A copy is always made so that the list stored in ``node.args`` is
        never mutated in place (FX argument lists reject in-place updates).
        """
        if isinstance(values, int):
            return [values] * 3
        triple = list(values)
        return triple * 3 if len(triple) == 1 else triple

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Create the ``PoolAvg3d`` op wrapper for ``node``.

        Args:
            node: The ``avg_pool3d`` FX node being lowered.
            nodes_to_wrappers: Cache of tensor wrappers, forwarded to
                ``define_tensor``.

        Returns:
            The populated op wrapper, or ``None`` (after emitting a warning)
            when ``divisor_override`` requests a divisor that differs from
            the kernel volume, which the ``PoolAvg3d`` parameters used here
            cannot express.
        """
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        args = node.args

        # kernel info
        filter_size = self._as_triple(cast(List[int], args[1]))

        # stride info: an omitted or empty stride means "same as kernel size"
        stride = self._as_triple(args[2]) if len(args) > 2 else []
        if not stride:
            stride = list(filter_size)

        # padding info
        padding = self._as_triple(args[3]) if len(args) > 3 else [0, 0, 0]

        # if ceil mode is True, use ceil instead of floor to compute the output shape
        mode = OpPoolAvg3d.RoundingMode.FLOOR
        if len(args) > 4 and cast(bool, args[4]):
            mode = OpPoolAvg3d.RoundingMode.CEIL

        # aten's count_include_pad defaults to True when it is not given
        count_pad_for_edges = cast(bool, args[5]) if len(args) > 5 else True

        # No divisor parameter is passed to PoolAvg3d, so only an override
        # equal to the kernel volume can be accepted here.
        # NOTE(review): confirm this also holds for ceil_mode edge windows.
        if len(args) > 6 and args[6] is not None:
            pooling_region = int(np.prod(filter_size))
            if cast(int, args[6]) != pooling_region:
                warnings.warn(
                    "[QNN Delegate Op Builder]: Not support divisor_override "
                    "which is not equal to pooling region, fallback op",
                    stacklevel=1,
                )
                return

        # symmetric (before, after) padding for depth, height and width
        shape_pad = [[padding[axis], padding[axis]] for axis in range(3)]

        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        avg_pool3d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolAvg3d.op_name,
        )

        avg_pool3d_op.AddInputTensors([input_tensor_wrapper])
        avg_pool3d_op.AddOutputTensors([output_tensor_wrapper])

        # filter_size and stride are rank-1 [3]; pad_amount is rank-2 [3, 2]
        tensor_params = (
            (OpPoolAvg3d.param_filter_size, filter_size),
            (OpPoolAvg3d.param_stride, stride),
            (OpPoolAvg3d.param_pad_amount, shape_pad),
        )
        for param_name, param_values in tensor_params:
            param_data = np.array(param_values, dtype=np.uint32)
            avg_pool3d_op.AddTensorParam(
                param_name,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
                param_data.ndim,
                list(param_data.shape),
                param_data,
                True,
            )

        avg_pool3d_op.AddScalarParam(
            OpPoolAvg3d.param_count_pad_for_edges,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: count_pad_for_edges},
        )

        avg_pool3d_op.AddScalarParam(
            OpPoolAvg3d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )

        return avg_pool3d_op
153+
154+
155+
@register_node_visitor
class AdaptiveAvgPool3d(NodeVisitor):
    """Builder that lowers ``aten._adaptive_avg_pool3d.default`` to QNN ``PoolAvg3d``.

    Adaptive pooling is expressed as a fixed kernel/stride average pool with
    zero padding. That is only exact when every adaptive window has the same
    size and the windows are evenly spaced; other configurations fall back.
    """

    target = ["aten._adaptive_avg_pool3d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _uniform_window(input_size: int, output_size: int):
        """Return ``(stride, kernel)`` reproducing adaptive pooling, else ``None``.

        Adaptive pooling averages, for output index ``i``, the input range
        ``[floor(i * in / out), ceil((i + 1) * in / out))``. The candidate
        ``stride = in // out`` and ``kernel = in - (out - 1) * stride`` is
        accepted only if it yields exactly those ranges for every ``i``.
        """
        if input_size <= 0 or output_size <= 0:
            return None
        stride = input_size // output_size
        kernel = input_size - (output_size - 1) * stride
        for idx in range(output_size):
            start = (idx * input_size) // output_size
            # integer ceil of (idx + 1) * input_size / output_size
            end = ((idx + 1) * input_size + output_size - 1) // output_size
            if start != idx * stride or end != idx * stride + kernel:
                return None
        return stride, kernel

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Create the ``PoolAvg3d`` op wrapper for ``node``.

        Args:
            node: The ``_adaptive_avg_pool3d`` FX node being lowered.
            nodes_to_wrappers: Cache of tensor wrappers, forwarded to
                ``define_tensor``.

        Returns:
            The populated op wrapper, or ``None`` (after emitting a warning)
            when the requested output size cannot be reproduced by a single
            uniform kernel/stride pair.
        """
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        # NOTE: This operator is layout sensitive, so the input tensor shape is always N,D,H,W,C.
        input_sizes = [input_tensor.shape[axis] for axis in (1, 2, 3)]

        # A ``None`` entry in the requested output size keeps the input extent.
        requested = node.args[1]
        output_sizes = [
            in_size if requested[idx] is None else requested[idx]
            for idx, in_size in enumerate(input_sizes)
        ]

        # kernel info & stride info, ordered depth, height, width
        windows = [
            self._uniform_window(in_size, out_size)
            for in_size, out_size in zip(input_sizes, output_sizes)
        ]
        if any(window is None for window in windows):
            warnings.warn(
                "[QNN Delegate Op Builder]: Depth or Height or Width is not suitable, fallback op",
                stacklevel=1,
            )
            return

        stride = [window[0] for window in windows]
        filter_size = [window[1] for window in windows]

        count_pad_for_edges = False
        # This operator use the default rounding mode of avg_pool3d, floor.
        mode = OpPoolAvg3d.RoundingMode.FLOOR

        # (before, after) padding per spatial axis; adaptive pooling never pads
        shape_pad = [[0, 0], [0, 0], [0, 0]]

        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        adaptive_avg_pool3d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolAvg3d.op_name,
        )

        adaptive_avg_pool3d_op.AddInputTensors([input_tensor_wrapper])
        adaptive_avg_pool3d_op.AddOutputTensors([output_tensor_wrapper])

        # filter_size and stride are rank-1 [3]; pad_amount is rank-2 [3, 2]
        tensor_params = (
            (OpPoolAvg3d.param_filter_size, filter_size),
            (OpPoolAvg3d.param_stride, stride),
            (OpPoolAvg3d.param_pad_amount, shape_pad),
        )
        for param_name, param_values in tensor_params:
            param_data = np.array(param_values, dtype=np.uint32)
            adaptive_avg_pool3d_op.AddTensorParam(
                param_name,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
                param_data.ndim,
                list(param_data.shape),
                param_data,
                True,
            )

        adaptive_avg_pool3d_op.AddScalarParam(
            OpPoolAvg3d.param_count_pad_for_edges,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: count_pad_for_edges},
        )

        adaptive_avg_pool3d_op.AddScalarParam(
            OpPoolAvg3d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )

        return adaptive_avg_pool3d_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,21 @@ class RoundingMode(IntEnum):
398398
CEIL = 1
399399

400400

401+
@dataclass(frozen=True, init=False)
class OpPoolAvg3d:
    """String constants naming the QNN ``PoolAvg3d`` op and its parameters."""

    # Operator type name.
    op_name: str = "PoolAvg3d"

    # Tensor parameters.
    param_filter_size: str = "filter_size"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"

    # Scalar parameters.
    param_count_pad_for_edges: str = "count_pad_for_edges"
    param_rounding_mode: str = "rounding_mode"

    @unique
    class RoundingMode(IntEnum):
        """Values accepted by the ``rounding_mode`` parameter."""

        FLOOR = 0
        CEIL = 1
414+
415+
401416
@dataclass(init=False, frozen=True)
402417
class OpPoolMax2d:
403418
op_name: str = "PoolMax2d"

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
]
2020

2121
to_be_implemented_operator = [
22-
exir_ops.edge.aten._adaptive_avg_pool3d.default,
2322
exir_ops.edge.aten.adaptive_max_pool2d.default,
2423
exir_ops.edge.aten.adaptive_max_pool3d.default,
25-
exir_ops.edge.aten.avg_pool3d.default,
2624
exir_ops.edge.aten.div.Tensor_mode,
2725
exir_ops.edge.aten.log10.default,
2826
exir_ops.edge.aten.log1p.default,

backends/qualcomm/quantizer/annotators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,18 @@ def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> N
578578
annotate_single_in_single_out(node, quantization_config)
579579

580580

581+
@register_annotator([torch.ops.aten.avg_pool3d.default])
def annotate_avgpool3d(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate an ``aten.avg_pool3d`` node for quantization.

    Delegates to ``annotate_single_in_single_out``, forwarding ``node`` and
    ``quantization_config`` unchanged.
    """
    annotate_single_in_single_out(node, quantization_config)
584+
585+
586+
@register_annotator([torch.ops.aten.adaptive_avg_pool3d.default])
def annotate_adaptive_avgpool3d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate an ``aten.adaptive_avg_pool3d`` node for quantization.

    Delegates to ``annotate_single_in_single_out``, forwarding ``node`` and
    ``quantization_config`` unchanged.
    """
    annotate_single_in_single_out(node, quantization_config)
591+
592+
581593
@register_annotator([torch.ops.aten.permute.default])
582594
def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
583595
annotate_in_out_obs_sharing_op(node, quantization_config)

0 commit comments

Comments
 (0)