Skip to content

Commit 30df6a3

Browse files
billmguofacebook-github-bot
authored andcommitted
Add to.dtype and neg ops (#8041)
Summary: Pull Request resolved: #8041 #8041 add neg ops and annonation of to.dtype Differential Revision: D68815927
1 parent 69fef28 commit 30df6a3

File tree

8 files changed

+92
-2
lines changed

8 files changed

+92
-2
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class LayoutTransform(ExportPass):
7575
exir_ops.edge.aten.mean.dim,
7676
exir_ops.edge.aten.minimum.default,
7777
exir_ops.edge.aten.mul.Tensor,
78+
exir_ops.edge.aten.neg.default,
7879
exir_ops.edge.aten.pow.Tensor_Scalar,
7980
exir_ops.edge.aten.prelu.default,
8081
exir_ops.edge.aten.repeat.default,
@@ -87,6 +88,7 @@ class LayoutTransform(ExportPass):
8788
exir_ops.edge.aten.sum.dim_IntList,
8889
exir_ops.edge.aten.topk.default,
8990
exir_ops.edge.aten._to_copy.default,
91+
exir_ops.edge.aten.to.dtype,
9092
*q_ops,
9193
*dq_ops,
9294
_operator.getitem,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
op_mean_dim,
4848
op_min,
4949
op_mul,
50+
op_neg,
5051
op_pad,
5152
op_pow,
5253
op_prelu,
@@ -120,6 +121,7 @@
120121
op_mean_dim,
121122
op_min,
122123
op_mul,
124+
op_neg,
123125
op_pad,
124126
op_pow,
125127
op_prelu,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
import torch
10+
11+
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .qnn_constants import OpElementWiseNeg, QNN_OP_PACKAGE_NAME_QTI_AISW
13+
14+
15+
@register_node_visitor
16+
class Neg(NodeVisitor):
17+
target = ["aten.neg.default"]
18+
19+
def __init__(self, *args) -> None:
20+
super().__init__(*args)
21+
22+
def define_node(
23+
self,
24+
node: torch.fx.Node,
25+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
26+
) -> PyQnnWrapper.PyQnnOpWrapper:
27+
input_node = node.args[0]
28+
input_tensor = self.get_tensor(input_node, node)
29+
neg_inp_tensor_wrapper = self.define_tensor(
30+
input_node,
31+
node,
32+
input_tensor,
33+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
34+
nodes_to_wrappers,
35+
)
36+
neg_input_tensors = [neg_inp_tensor_wrapper]
37+
output_tensor = self.get_tensor(node, node)
38+
output_tensor_wrapper = self.define_tensor(
39+
node,
40+
node,
41+
output_tensor,
42+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
43+
nodes_to_wrappers,
44+
)
45+
neg_output_tensors = [output_tensor_wrapper]
46+
neg_op = PyQnnWrapper.PyQnnOpWrapper(
47+
node.name,
48+
QNN_OP_PACKAGE_NAME_QTI_AISW,
49+
OpElementWiseNeg.op_name,
50+
)
51+
neg_op.AddInputTensors(neg_input_tensors)
52+
neg_op.AddOutputTensors(neg_output_tensors)
53+
return neg_op

backends/qualcomm/builders/op_to.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
@register_node_visitor
1818
class To(NodeVisitor):
19-
target = ["aten._to_copy.default", "dim_order_ops._to_dim_order_copy.default"]
19+
target = [
20+
"aten._to_copy.default",
21+
"dim_order_ops._to_dim_order_copy.default",
22+
"aten.to.dtype",
23+
]
2024
sufixed_8_offset_diff = 128
2125
sufixed_16_offset_diff = 32768
2226
epsilon = 1e-6

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ class OpElementWiseMultiply:
145145
op_name: str = "ElementWiseMultiply"
146146

147147

148+
@dataclass(init=False, frozen=True)
149+
class OpElementWiseNeg:
150+
op_name: str = "ElementWiseNeg"
151+
152+
148153
@dataclass(init=False, frozen=True)
149154
class OpElementWiseNeuron:
150155
op_name: str = "ElementWiseNeuron"

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,11 @@ def annotate_max_pool2d_with_indices(
398398
annotate_single_in_single_out(node, quantization_config)
399399

400400

401+
@register_annotator([torch.ops.aten.neg.default])
402+
def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None:
403+
annotate_single_in_single_out(node, quantization_config)
404+
405+
401406
@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
402407
def annotate_adaptive_avgpool2d(
403408
node: Node, quantization_config: QuantizationConfig

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,14 @@ def forward(self, x):
893893
return attn_output
894894

895895

896+
class Neg(torch.nn.Module):
897+
def __init__(self):
898+
super().__init__()
899+
900+
def forward(self, x):
901+
return torch.neg(x)
902+
903+
896904
class Pad(torch.nn.Module):
897905
def __init__(self):
898906
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,11 @@ def test_qnn_backend_argmin(self):
685685
module = Conv2dArgmin() # noqa: F405
686686
sample_input = (torch.randn(16, 3, 16, 16),)
687687
self.lower_module_and_test_output(module, sample_input)
688+
689+
def test_qnn_backend_neg(self):
690+
module = Neg() # noqa: F405
691+
sample_input = (torch.randn(1, 4, 16, 16),)
692+
self.lower_module_and_test_output(module, sample_input)
688693

689694
class TestQNNFloatingPointModel(TestQNN):
690695
# TODO: refactor to support different backends
@@ -1421,7 +1426,7 @@ def test_qnn_backend_minimum(self):
14211426
sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
14221427
module = self.get_qdq_module(module, sample_input)
14231428
self.lower_module_and_test_output(module, sample_input)
1424-
1429+
14251430
def test_qnn_backend_pad(self):
14261431
module = Pad() # noqa: F405
14271432
sample_input = (torch.randn([1, 8, 128]),)
@@ -1596,6 +1601,12 @@ def test_qnn_backend_argmin(self):
15961601
sample_input = (torch.randn(16, 3, 16, 16),)
15971602
module = self.get_qdq_module(module, sample_input)
15981603
self.lower_module_and_test_output(module, sample_input)
1604+
1605+
def test_qnn_backend_neg(self):
1606+
module = Neg() # noqa: F405
1607+
sample_input = (torch.randn(1, 4, 16, 16),)
1608+
module = self.get_qdq_module(module, sample_input)
1609+
self.lower_module_and_test_output(module, sample_input)
15991610

16001611

16011612
class TestQNNQuantizedModel(TestQNN):

0 commit comments

Comments
 (0)