
Commit dbeddf3

cccclai authored and facebook-github-bot committed
Add op_amax support
Summary: As title: add op_amax to support an internal model, and add a unit test in test_qnn_delegate.py.

Differential Revision: D72613814
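For context, torch.amax reduces a tensor to its maximum over the given dimensions, and on the QNN side it lowers to the ReduceMax op (see op_amax.py below). A minimal illustration of the operator's semantics:

    import torch

    x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
    print(torch.amax(x, dim=1))                # tensor([5., 3.])
    print(torch.amax(x, dim=1, keepdim=True))  # tensor([[5.], [3.]]), rank preserved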
1 parent 2cce2db commit dbeddf3

File tree

6 files changed: +120 −0 lines changed


backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,
@@ -95,6 +96,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,
backends/qualcomm/builders/op_amax.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpAmax, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AMax(NodeVisitor):
+    target = ["aten.amax.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        # reduction dims: normalize negative dims, then remap them if the
+        # delegate has recorded a permuted axis order for this node
+        mean_dims = cast(List[int], node.args[1])
+        mean_dims = [
+            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
+        ]
+        if QCOM_AXIS_ORDER in node.meta:
+            mean_dims = [
+                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+            ]
+        mean_dims_shape = [len(mean_dims)]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        reduce_max_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpAmax.op_name,
+        )
+        reduce_max_op.AddInputTensors([input_tensor_wrapper])
+        reduce_max_op.AddOutputTensors([output_tensor_wrapper])
+        reduce_max_op.AddTensorParam(
+            OpAmax.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(mean_dims_shape),
+            mean_dims_shape,
+            np.array(mean_dims, dtype=np.uint32),
+            True,
+        )
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            reduce_max_op.AddScalarParam(
+                OpAmax.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return reduce_max_op
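The two list comprehensions in define_node first normalize negative dims against the input rank, then remap them when the delegate has recorded a permuted axis order in node.meta. A standalone sketch of that bookkeeping; the axis order here is a hypothetical NCHW-to-NHWC permutation, not taken from the commit:

    # Hypothetical: rank-4 input permuted NCHW -> NHWC by the backend.
    axis_order = (0, 2, 3, 1)  # position i in the new layout holds old dim axis_order[i]
    rank = 4

    dims = [-3]                                 # amax over channels, given as a negative dim
    dims = [d % rank for d in dims]             # normalize: [-3] -> [1]
    dims = [axis_order.index(d) for d in dims]  # channels now sit at index 3: [1] -> [3]
    print(dims)  # [3]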

backends/qualcomm/builders/qnn_constants.py

Lines changed: 6 additions & 0 deletions
@@ -418,6 +418,12 @@ class OpScatterNd:
 class OpSigmoid:
     op_name: str = "Sigmoid"

+@dataclass(init=False, frozen=True)
+class OpAmax:
+    op_name: str = "ReduceMax"
+    param_axes: str = "axes"
+    param_keep_dims: str = "keep_dims"
+

 @dataclass(init=False, frozen=True)
 class OpSoftmax:
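OpAmax follows the file's existing pattern: a frozen dataclass with init disabled, used purely as a namespace of string constants so the builder can reference QNN's operator and parameter names without instantiating anything. Mirroring the declaration above:

    from dataclasses import dataclass

    @dataclass(init=False, frozen=True)
    class OpAmax:
        op_name: str = "ReduceMax"
        param_axes: str = "axes"
        param_keep_dims: str = "keep_dims"

    # Read as class attributes; no instance is ever created.
    assert OpAmax.op_name == "ReduceMax"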

backends/qualcomm/quantizer/annotators.py

Lines changed: 3 additions & 0 deletions
@@ -181,6 +181,9 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
 def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)

+@register_annotator([torch.ops.aten.amax.default])
+def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)

 @register_annotator([torch.ops.aten.argmin.default])
 def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
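The annotator reuses annotate_binary and is keyed on the ATen overload that torch.amax traces to. A quick way to confirm that a model actually produces aten.amax.default (a sketch using torch.export, not part of the commit):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.amax(x, dim=1)

    ep = torch.export.export(M(), (torch.randn(4, 4),))
    # torch.ops.aten.amax.default should appear among the call_function targets
    print([n.target for n in ep.graph.nodes if n.op == "call_function"])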

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 0 deletions
@@ -72,6 +72,16 @@ def forward(self, x):
         return torch.any(x, dim=self.dim, keepdim=self.keepdim)


+class AMax(torch.nn.Module):
+    def __init__(self, dim=None, keepdim=False):
+        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
+
+
 class Arange(torch.nn.Module):
     def __init__(self, start, end, step, dtype):
         super().__init__()
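The new test module is a thin wrapper so the delegate tests can trace torch.amax through an nn.Module. Used on its own it behaves as expected:

    import torch

    m = AMax(dim=1, keepdim=True)  # the module defined above
    y = m(torch.randn(4, 4))
    print(y.shape)  # torch.Size([4, 1]); without keepdim it would be torch.Size([4])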

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 16 additions & 0 deletions
@@ -113,6 +113,13 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         sample_input = (torch.randn(1, 512, 7, 7),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_amax(self):
+        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -1111,6 +1118,15 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+
+    def test_qnn_backend_amax(self):
+        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
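Both tests run the keepdim=False and keepdim=True variants under subTest, so a failure in one variant is reported without aborting the other; the quantized version additionally round-trips each module through get_qdq_module first. The subTest pattern in isolation, independent of the QNN harness:

    import unittest

    import torch

    class AmaxSubTestPattern(unittest.TestCase):
        def test_variants(self):
            for i, keepdim in enumerate([False, True]):
                with self.subTest(i=i):
                    y = torch.amax(torch.randn(4, 4), dim=1, keepdim=keepdim)
                    self.assertEqual(y.shape, (4, 1) if keepdim else (4,))

    if __name__ == "__main__":
        unittest.main()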
