
Commit b56e33b

billmguo authored and facebook-github-bot committed
add argmin, add_ ops (#8040)
Summary:
Pull Request resolved: #8040

Add argmin and add_ ops annotation.

Differential Revision: D68811662
1 parent fccb352 commit b56e33b

File tree

8 files changed: +122 −4 lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ class LayoutTransform(ExportPass):
     layout_agnostic_ops = {
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.argmin.default,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.ceil.default,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_arange,
+    op_argmin,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
@@ -82,6 +83,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_arange,
+    op_argmin,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,

backends/qualcomm/builders/op_add.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 
 @register_node_visitor
 class Add(NodeVisitor):
-    target = ["aten.add.Tensor"]
+    target = ["aten.add.Tensor", "aten.add_.Tensor"]
 
     def __init__(self, *args) -> None:
         super().__init__(*args)
backends/qualcomm/builders/op_argmin.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Argmin(NodeVisitor):
+    target = ["aten.argmin.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        argmin_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        argmin_input_tensors = [argmin_inp_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        argmin_output_tensors = [output_tensor_wrapper]
+
+        dim = cast(int, node.args[1])
+        if dim < 0:
+            dim = dim % len(input_tensor.shape)
+        if QCOM_AXIS_ORDER in node.meta:
+            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
+
+        argmin_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpArgmin.op_name,
+        )
+        argmin_op.AddInputTensors(argmin_input_tensors)
+        argmin_op.AddOutputTensors(argmin_output_tensors)
+
+        argmin_op.AddScalarParam(
+            OpArgmin.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            argmin_op.AddScalarParam(
+                OpArgmin.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return argmin_op
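
A note on the axis handling in define_node above: a negative dim is first normalized into [0, rank), and if the layout-transform pass has reordered the tensor, the dim is remapped through the recorded permutation in node.meta[QCOM_AXIS_ORDER]. A minimal standalone sketch of that logic, where axis_order is a hypothetical stand-in for node.meta[QCOM_AXIS_ORDER] (e.g. an NCHW-to-NHWC reorder recorded as (0, 2, 3, 1)):

# Standalone sketch of the dim normalization used in define_node above.
def normalize_dim(dim: int, rank: int, axis_order=None) -> int:
    if dim < 0:
        dim = dim % rank  # e.g. dim=-1 with rank=4 becomes 3
    if axis_order is not None:
        # index() answers: at which position did original axis `dim` land?
        dim = axis_order.index(dim)
    return dim

assert normalize_dim(-1, 4) == 3
assert normalize_dim(1, 4, axis_order=(0, 2, 3, 1)) == 3  # channel axis moved last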

backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions

@@ -301,6 +301,13 @@ class OpReduceMean:
     param_keep_dims: str = "keep_dims"
 
 
+@dataclass(init=False, frozen=True)
+class OpArgmin:
+    op_name: str = "Argmin"
+    param_axis: str = "axis"
+    param_keep_dims: str = "keep_dims"
+
+
 @dataclass(init=False, frozen=True)
 class OpReduceSum:
     op_name: str = "ReduceSum"
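
The new OpArgmin block follows the file's existing convention: a frozen, init-less dataclass used purely as a namespace of QNN op and parameter name strings, read directly off the class (as op_argmin.py does above). A minimal sketch of the pattern:

from dataclasses import dataclass

@dataclass(init=False, frozen=True)
class OpArgmin:
    op_name: str = "Argmin"
    param_axis: str = "axis"
    param_keep_dims: str = "keep_dims"

# Never instantiated; fields are read as class attributes.
assert OpArgmin.op_name == "Argmin"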

backends/qualcomm/quantizer/annotators.py

Lines changed: 7 additions & 1 deletion

@@ -162,7 +162,9 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
     )
 
 
-@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
+@register_annotator(
+    [torch.ops.aten.add, torch.ops.aten.add_.Tensor, torch.ops.aten.add.Tensor]
+)
 def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
@@ -476,6 +478,10 @@ def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None
 def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
+@register_annotator([torch.ops.aten.argmin.default])
+def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
 
 @register_annotator([torch.ops.aten.log.default])
 def annotate_log(node: Node, quantization_config: QuantizationConfig) -> None:
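
For context on the annotate_add change (not part of the diff): aten.add_.Tensor is the in-place variant of aten.add.Tensor, so captured graphs containing in-place adds can now be annotated the same way as functional ones. A quick illustration of the two overloads:

import torch

x = torch.ones(2)
y = torch.full((2,), 2.0)

out = torch.ops.aten.add.Tensor(x, y)  # functional overload: returns a new tensor
torch.ops.aten.add_.Tensor(x, y)       # in-place overload: mutates x
assert torch.equal(out, x)             # both now hold tensor([3., 3.])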

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 1 deletion

@@ -274,7 +274,6 @@ def forward(self, x):
         x = self.logsoftmax(x)
         return x
 
-
 class Conv2dAvgPool2d(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -287,6 +286,18 @@ def forward(self, x):
         return self.pool(self.conv(x))
 
 
+class Conv2dArgmin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            3, 16, 7, bias=True, stride=2, padding=3, dilation=1
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.argmin(x, dim=0, keepdim=True)
+
+
 class Conv2dBnHardtanhMean(torch.nn.Module):
     def __init__(self):
         super(Conv2dBnHardtanhMean, self).__init__()
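
As a shape check (not part of the diff): Conv2dArgmin reduces over dim=0 with keepdim=True, so the batch dimension collapses to 1 and the result is an integer index tensor. With the sample input the tests use below:

import torch

x = torch.randn(16, 3, 16, 16)  # sample input from the tests
conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
out = torch.argmin(conv(x), dim=0, keepdim=True)
print(out.shape, out.dtype)     # torch.Size([1, 16, 8, 8]) torch.int64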

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 14 additions & 1 deletion

@@ -681,6 +681,10 @@ def test_qnn_backend_view(self):
         sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 16, 16),)
+        self.lower_module_and_test_output(module, sample_input)
 
 class TestQNNFloatingPointModel(TestQNN):
     # TODO: refactor to support different backends
@@ -704,7 +708,8 @@ def test_qnn_backend_chunk_add(self):
         torch.manual_seed(8)
         sample_input = (torch.randn(1, 2, 4, 2),)
         self.lower_module_and_test_output(module, sample_input)
-
+
+
     def test_qnn_backend_conv1d_relu_log_softmax(self):
         module = Conv1dReluLogSoftmax()  # noqa: F405
         sample_input = (torch.rand(1, 2, 28),)
@@ -1585,6 +1590,12 @@ def test_qnn_backend_view(self):
         sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 16, 16),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
 
 
 class TestQNNQuantizedModel(TestQNN):
@@ -1610,6 +1621,8 @@ def test_qnn_backend_chunk_add(self):
         sample_input = (torch.randn(1, 1, 4, 2),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
+
+
 
     def test_qnn_backend_conv1d_relu_log_softmax(self):
         module = Conv1dReluLogSoftmax()  # noqa: F405
