
Commit 19d7f42

billmguo authored and facebook-github-bot committed
Add to.dtype and neg ops (#8041)
Summary: Pull Request resolved: #8041. Adds neg ops and an annotation for to.dtype. Reviewed By: Andriyluck. Differential Revision: D68815927
1 parent 75800ee commit 19d7f42

File tree

backends/qualcomm/_passes/layout_transform.py
backends/qualcomm/builders/__init__.py
backends/qualcomm/builders/op_neg.py (new)
backends/qualcomm/builders/qnn_constants.py
backends/qualcomm/quantizer/annotators.py
backends/qualcomm/tests/models.py
backends/qualcomm/tests/test_qnn_delegate.py

7 files changed: +87 -8 lines
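In eager PyTorch terms, the operator being wired through the QNN delegate is plain elementwise negation (aten.neg.default). A minimal sketch of the semantics the backend must reproduce, using stock PyTorch only:

    import torch

    x = torch.randn(1, 4, 16, 16)  # the shape the new tests use
    # aten.neg.default is elementwise negation: y = -x
    assert torch.equal(torch.neg(x), -x)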

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.mean.dim,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.neg.default,
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.repeat.default,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -47,6 +47,7 @@
     op_mean_dim,
     op_min,
     op_mul,
+    op_neg,
     op_pad,
     op_pow,
     op_prelu,
@@ -120,6 +121,7 @@
     op_mean_dim,
     op_min,
     op_mul,
+    op_neg,
     op_pad,
     op_pow,
     op_prelu,
backends/qualcomm/builders/op_neg.py (new file)

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseNeg, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Neg(NodeVisitor):
+    target = ["aten.neg.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        neg_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        neg_input_tensors = [neg_inp_tensor_wrapper]
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        neg_output_tensors = [output_tensor_wrapper]
+        neg_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseNeg.op_name,
+        )
+        neg_op.AddInputTensors(neg_input_tensors)
+        neg_op.AddOutputTensors(neg_output_tensors)
+        return neg_op
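The builder follows the backend's per-op NodeVisitor pattern: it registers against the "aten.neg.default" target, wraps the node's single input and single output as native QNN tensors, and emits one ElementWiseNeg op between them. As a sanity check of the target name, exporting a trivial module shows the aten.neg.default call the visitor matches (a sketch assuming a recent PyTorch with torch.export; the module name is illustrative):

    import torch

    class NegModule(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    # The exported graph contains a single call to aten.neg.default; the edge
    # dialect later rewrites it to exir_ops.edge.aten.neg.default, the name
    # registered in layout_transform.py above.
    ep = torch.export.export(NegModule(), (torch.randn(1, 4, 16, 16),))
    print(ep.graph)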

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions

@@ -145,6 +145,11 @@ class OpElementWiseMultiply:
     op_name: str = "ElementWiseMultiply"


+@dataclass(init=False, frozen=True)
+class OpElementWiseNeg:
+    op_name: str = "ElementWiseNeg"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseNeuron:
     op_name: str = "ElementWiseNeuron"

backends/qualcomm/quantizer/annotators.py

Lines changed: 6 additions & 5 deletions

@@ -162,9 +162,7 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
     )


-@register_annotator(
-    [torch.ops.aten.add, torch.ops.aten.add.Tensor]
-)
+@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
 def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)

@@ -398,6 +396,11 @@ def annotate_max_pool2d_with_indices(
     annotate_single_in_single_out(node, quantization_config)


+@register_annotator([torch.ops.aten.neg.default])
+def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
 def annotate_adaptive_avgpool2d(
     node: Node, quantization_config: QuantizationConfig
@@ -479,8 +482,6 @@ def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None
     annotate_single_in_single_out(node, quantization_config)


-
-
 @register_annotator([torch.ops.aten.log.default])
 def annotate_log(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
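Because negation takes one tensor in and one tensor out, it reuses annotate_single_in_single_out rather than the binary annotator used for add. The op is also quantization-friendly: negation mirrors the value range, so the absolute maximum, and hence a symmetric scale derived from it, is unchanged. A quick eager-mode check of that property:

    import torch

    x = torch.randn(1, 4, 16, 16)
    y = torch.neg(x)
    # The range is mirrored, so the absolute max of input and output match exactly.
    assert torch.equal(x.abs().max(), y.abs().max())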

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions

@@ -894,6 +894,14 @@ def forward(self, x):
         return attn_output


+class Neg(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.neg(x)
+
+
 class Pad(torch.nn.Module):
     def __init__(self):
         super().__init__()
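The test model is deliberately stateless, so the exported graph contains exactly one negation node. In eager mode it behaves as plain negation (a minimal usage sketch; the class is redefined here so the snippet is self-contained):

    import torch

    class Neg(torch.nn.Module):  # mirrors the test model above
        def forward(self, x):
            return torch.neg(x)

    m = Neg()
    x = torch.randn(1, 4, 16, 16)
    assert torch.equal(m(x), -x)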

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 12 additions & 3 deletions

@@ -111,7 +111,7 @@ def test_qnn_backend_arange(self):
         for i, module in enumerate(modules):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
-
+
     def test_qnn_backend_argmin(self):
         module = Conv2dArgmin()  # noqa: F405
         sample_input = (torch.randn(16, 3, 16, 16),)
@@ -546,6 +546,11 @@ def test_qnn_backend_minimum(self):
         sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_neg(self):
+        module = Neg()  # noqa: F405
+        sample_input = (torch.randn(1, 4, 16, 16),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pad(self):
         module = Pad()  # noqa: F405
         sample_input = (torch.randn([1, 8, 128]),)
@@ -1434,6 +1439,12 @@ def test_qnn_backend_minimum(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_neg(self):
+        module = Neg()  # noqa: F405
+        sample_input = (torch.randn(1, 4, 16, 16),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pad(self):
         module = Pad()  # noqa: F405
         sample_input = (torch.randn([1, 8, 128]),)
@@ -1603,8 +1614,6 @@ def test_qnn_backend_view(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

-
-

 class TestQNNQuantizedModel(TestQNN):
     # TODO: refactor to support different backends
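As with the other ops, neg lands with two tests: a float path, and a quantized path that first converts the module with get_qdq_module. A rough eager-mode analogue of what the quantized path asserts, using stock PyTorch per-tensor quantization rather than the actual QNN pipeline:

    import torch

    x = torch.randn(1, 4, 16, 16)
    scale = float(x.abs().max() / 127)
    q = torch.quantize_per_tensor(x, scale=scale, zero_point=0, dtype=torch.qint8)
    # Negating the dequantized tensor should track the float reference to
    # within one quantization step.
    assert torch.allclose(torch.neg(q.dequantize()), -x, atol=scale)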
