Skip to content

Commit 183fa36

Browse files
billmguofacebook-github-bot
authored andcommitted
Add to.dtype and neg ops (#8041)
Summary: #8041 add neg ops and annonation of to.dtype Reviewed By: Andriyluck Differential Revision: D68815927
1 parent 1d43d91 commit 183fa36

File tree

7 files changed

+85
-0
lines changed

7 files changed

+85
-0
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class LayoutTransform(ExportPass):
7474
exir_ops.edge.aten.mean.dim,
7575
exir_ops.edge.aten.minimum.default,
7676
exir_ops.edge.aten.mul.Tensor,
77+
exir_ops.edge.aten.neg.default,
7778
exir_ops.edge.aten.pow.Tensor_Scalar,
7879
exir_ops.edge.aten.prelu.default,
7980
exir_ops.edge.aten.repeat.default,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
op_mean_dim,
4747
op_min,
4848
op_mul,
49+
op_neg,
4950
op_pad,
5051
op_pow,
5152
op_prelu,
@@ -118,6 +119,7 @@
118119
op_mean_dim,
119120
op_min,
120121
op_mul,
122+
op_neg,
121123
op_pad,
122124
op_pow,
123125
op_prelu,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
import torch
10+
11+
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .qnn_constants import OpElementWiseNeg, QNN_OP_PACKAGE_NAME_QTI_AISW
13+
14+
15+
@register_node_visitor
16+
class Neg(NodeVisitor):
17+
target = ["aten.neg.default"]
18+
19+
def __init__(self, *args) -> None:
20+
super().__init__(*args)
21+
22+
def define_node(
23+
self,
24+
node: torch.fx.Node,
25+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
26+
) -> PyQnnWrapper.PyQnnOpWrapper:
27+
input_node = node.args[0]
28+
input_tensor = self.get_tensor(input_node, node)
29+
neg_inp_tensor_wrapper = self.define_tensor(
30+
input_node,
31+
node,
32+
input_tensor,
33+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
34+
nodes_to_wrappers,
35+
)
36+
neg_input_tensors = [neg_inp_tensor_wrapper]
37+
output_tensor = self.get_tensor(node, node)
38+
output_tensor_wrapper = self.define_tensor(
39+
node,
40+
node,
41+
output_tensor,
42+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
43+
nodes_to_wrappers,
44+
)
45+
neg_output_tensors = [output_tensor_wrapper]
46+
neg_op = PyQnnWrapper.PyQnnOpWrapper(
47+
node.name,
48+
QNN_OP_PACKAGE_NAME_QTI_AISW,
49+
OpElementWiseNeg.op_name,
50+
)
51+
neg_op.AddInputTensors(neg_input_tensors)
52+
neg_op.AddOutputTensors(neg_output_tensors)
53+
return neg_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ class OpElementWiseMultiply:
145145
op_name: str = "ElementWiseMultiply"
146146

147147

148+
@dataclass(init=False, frozen=True)
149+
class OpElementWiseNeg:
150+
op_name: str = "ElementWiseNeg"
151+
152+
148153
@dataclass(init=False, frozen=True)
149154
class OpElementWiseNeuron:
150155
op_name: str = "ElementWiseNeuron"

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ def annotate_max_pool2d_with_indices(
396396
annotate_single_in_single_out(node, quantization_config)
397397

398398

399+
@register_annotator([torch.ops.aten.neg.default])
400+
def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None:
401+
annotate_single_in_single_out(node, quantization_config)
402+
403+
399404
@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
400405
def annotate_adaptive_avgpool2d(
401406
node: Node, quantization_config: QuantizationConfig

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,14 @@ def forward(self, x):
882882
return attn_output
883883

884884

885+
class Neg(torch.nn.Module):
886+
def __init__(self):
887+
super().__init__()
888+
889+
def forward(self, x):
890+
return torch.neg(x)
891+
892+
885893
class Pad(torch.nn.Module):
886894
def __init__(self):
887895
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,11 @@ def test_qnn_backend_minimum(self):
541541
sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
542542
self.lower_module_and_test_output(module, sample_input)
543543

544+
def test_qnn_backend_neg(self):
545+
module = Neg() # noqa: F405
546+
sample_input = (torch.randn(1, 4, 16, 16),)
547+
self.lower_module_and_test_output(module, sample_input)
548+
544549
def test_qnn_backend_pad(self):
545550
module = Pad() # noqa: F405
546551
sample_input = (torch.randn([1, 8, 128]),)
@@ -1418,6 +1423,12 @@ def test_qnn_backend_minimum(self):
14181423
module = self.get_qdq_module(module, sample_input)
14191424
self.lower_module_and_test_output(module, sample_input)
14201425

1426+
def test_qnn_backend_neg(self):
1427+
module = Neg() # noqa: F405
1428+
sample_input = (torch.randn(1, 4, 16, 16),)
1429+
module = self.get_qdq_module(module, sample_input)
1430+
self.lower_module_and_test_output(module, sample_input)
1431+
14211432
def test_qnn_backend_pad(self):
14221433
module = Pad() # noqa: F405
14231434
sample_input = (torch.randn([1, 8, 128]),)

0 commit comments

Comments
 (0)