Skip to content

Commit 34e55b7

Browse files
billmguofacebook-github-bot
authored andcommitted
add argmin, add_ ops (#8040)
Summary: #8040 add argmin, add_ ops annontation Differential Revision: D68811662
1 parent 4796da7 commit 34e55b7

File tree

8 files changed

+108
-2
lines changed

8 files changed

+108
-2
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class LayoutTransform(ExportPass):
4444
layout_agnostic_ops = {
4545
exir_ops.edge.aten.abs.default,
4646
exir_ops.edge.aten.add.Tensor,
47+
exir_ops.edge.aten.argmin.default,
4748
exir_ops.edge.aten.bmm.default,
4849
exir_ops.edge.aten.cat.default,
4950
exir_ops.edge.aten.ceil.default,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
op_abs,
1010
op_add,
1111
op_arange,
12+
op_argmin,
1213
op_avg_pool2d,
1314
op_batch_norm,
1415
op_bmm,
@@ -80,6 +81,7 @@
8081
op_abs,
8182
op_add,
8283
op_arange,
84+
op_argmin,
8385
op_avg_pool2d,
8486
op_batch_norm,
8587
op_bmm,

backends/qualcomm/builders/op_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Add(NodeVisitor):
18-
target = ["aten.add.Tensor"]
18+
target = ["aten.add.Tensor", "aten.add_.Tensor"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
import numpy as np
10+
import torch
11+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
12+
13+
from .node_visitor import NodeVisitor, register_node_visitor
14+
from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
15+
16+
17+
@register_node_visitor
18+
class Argmin(NodeVisitor):
19+
target = ["aten.argmin.default"]
20+
21+
def __init__(self, *args) -> None:
22+
super().__init__(*args)
23+
24+
def define_node(
25+
self,
26+
node: torch.fx.Node,
27+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
28+
) -> PyQnnWrapper.PyQnnOpWrapper:
29+
input_node = node.args[0]
30+
input_tensor = self.get_tensor(input_node, node)
31+
argmin_inp_tensor_wrapper = self.define_tensor(
32+
input_node,
33+
node,
34+
input_tensor,
35+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
36+
nodes_to_wrappers,
37+
)
38+
argmin_input_tensors = [argmin_inp_tensor_wrapper]
39+
40+
output_tensor = self.get_tensor(node, node)
41+
output_tensor_wrapper = self.define_tensor(
42+
node,
43+
node,
44+
output_tensor,
45+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
46+
nodes_to_wrappers,
47+
)
48+
argmin_output_tensors = [output_tensor_wrapper]
49+
50+
dim = cast(int, node.args[1])
51+
if dim < 0:
52+
dim = dim % len(input_tensor.shape)
53+
if QCOM_AXIS_ORDER in node.meta:
54+
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
55+
56+
argmin_op = PyQnnWrapper.PyQnnOpWrapper(
57+
node.name,
58+
QNN_OP_PACKAGE_NAME_QTI_AISW,
59+
OpArgmin.op_name,
60+
)
61+
argmin_op.AddInputTensors(argmin_input_tensors)
62+
argmin_op.AddOutputTensors(argmin_output_tensors)
63+
64+
argmin_op.AddScalarParam(
65+
OpArgmin.param_axis,
66+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
67+
{QCOM_DATA: np.uint32(dim)},
68+
)
69+
if len(node.args) > 2:
70+
keep_dims = cast(bool, node.args[2])
71+
argmin_op.AddScalarParam(
72+
OpArgmin.param_keep_dims,
73+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
74+
{QCOM_DATA: keep_dims},
75+
)

backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,13 @@ class OpReduceMean:
301301
param_keep_dims: str = "keep_dims"
302302

303303

304+
@dataclass(init=False, frozen=True)
305+
class OpArgmin:
306+
op_name: str = "Argmin"
307+
param_axis: str = "axis"
308+
param_keep_dims: str = "keep_dims"
309+
310+
304311
@dataclass(init=False, frozen=True)
305312
class OpReduceSum:
306313
op_name: str = "ReduceSum"

backends/qualcomm/quantizer/annotators.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,18 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
162162
)
163163

164164

165-
@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
165+
@register_annotator(
166+
[torch.ops.aten.add, torch.ops.aten.add_.Tensor, torch.ops.aten.add.Tensor]
167+
)
166168
def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
167169
annotate_binary(node, quantization_config)
168170

169171

172+
@register_annotator([torch.ops.aten.argmin.default])
173+
def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
174+
annotate_binary(node, quantization_config)
175+
176+
170177
@register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
171178
def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
172179
annotate_binary(node, quantization_config)

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def forward(self, y):
5757
)
5858

5959

60+
class Argmin(torch.nn.Module):
61+
def __init__(self):
62+
super().__init__()
63+
64+
def forward(self, x):
65+
return torch.argmin(x, dim=0, keepdim=True)
66+
67+
6068
class AvgPoolModule(torch.nn.Module):
6169
def __init__(self):
6270
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,12 @@ def test_qnn_backend_chunk_add(self):
15961596
module = self.get_qdq_module(module, sample_input)
15971597
self.lower_module_and_test_output(module, sample_input)
15981598

1599+
def test_qnn_backend_argmin(self):
1600+
module = Argmin() # noqa: F405
1601+
sample_input = (torch.randn(1, 3, 224, 224),)
1602+
module = self.get_qdq_module(module, sample_input)
1603+
self.lower_module_and_test_output(module, sample_input)
1604+
15991605
def test_qnn_backend_conv1d_relu_log_softmax(self):
16001606
module = Conv1dReluLogSoftmax() # noqa: F405
16011607
sample_input = (torch.rand(1, 2, 28),)

0 commit comments

Comments
 (0)