Skip to content

Commit a5874d9

Browse files
billmguofacebook-github-bot
authored andcommitted
Add to.dtype and neg ops (#8041)
Summary: #8041 add neg ops and annonation of to.dtype Differential Revision: D68815927
1 parent 4796da7 commit a5874d9

File tree

8 files changed

+84
-1
lines changed

8 files changed

+84
-1
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class LayoutTransform(ExportPass):
7373
exir_ops.edge.aten.mean.dim,
7474
exir_ops.edge.aten.minimum.default,
7575
exir_ops.edge.aten.mul.Tensor,
76+
exir_ops.edge.aten.neg.default,
7677
exir_ops.edge.aten.pow.Tensor_Scalar,
7778
exir_ops.edge.aten.prelu.default,
7879
exir_ops.edge.aten.repeat.default,
@@ -85,6 +86,7 @@ class LayoutTransform(ExportPass):
8586
exir_ops.edge.aten.sum.dim_IntList,
8687
exir_ops.edge.aten.topk.default,
8788
exir_ops.edge.aten._to_copy.default,
89+
exir_ops.edge.aten.to.dtype,
8890
*q_ops,
8991
*dq_ops,
9092
_operator.getitem,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
op_mean_dim,
4646
op_min,
4747
op_mul,
48+
op_neg,
4849
op_pad,
4950
op_pow,
5051
op_prelu,
@@ -116,6 +117,7 @@
116117
op_mean_dim,
117118
op_min,
118119
op_mul,
120+
op_neg,
119121
op_pad,
120122
op_pow,
121123
op_prelu,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
import torch
10+
11+
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .qnn_constants import OpElementWiseNeg, QNN_OP_PACKAGE_NAME_QTI_AISW
13+
14+
15+
@register_node_visitor
16+
class Neg(NodeVisitor):
17+
target = ["aten.neg.default"]
18+
19+
def __init__(self, *args) -> None:
20+
super().__init__(*args)
21+
22+
def define_node(
23+
self,
24+
node: torch.fx.Node,
25+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
26+
) -> PyQnnWrapper.PyQnnOpWrapper:
27+
input_node = node.args[0]
28+
input_tensor = self.get_tensor(input_node, node)
29+
neg_inp_tensor_wrapper = self.define_tensor(
30+
input_node,
31+
node,
32+
input_tensor,
33+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
34+
nodes_to_wrappers,
35+
)
36+
neg_input_tensors = [neg_inp_tensor_wrapper]
37+
output_tensor = self.get_tensor(node, node)
38+
output_tensor_wrapper = self.define_tensor(
39+
node,
40+
node,
41+
output_tensor,
42+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
43+
nodes_to_wrappers,
44+
)
45+
neg_output_tensors = [output_tensor_wrapper]
46+
neg_op = PyQnnWrapper.PyQnnOpWrapper(
47+
node.name,
48+
QNN_OP_PACKAGE_NAME_QTI_AISW,
49+
OpElementWiseNeg.op_name,
50+
)
51+
neg_op.AddInputTensors(neg_input_tensors)
52+
neg_op.AddOutputTensors(neg_output_tensors)
53+
return neg_op

backends/qualcomm/builders/op_to.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
@register_node_visitor
1818
class To(NodeVisitor):
19-
target = ["aten._to_copy.default", "dim_order_ops._to_dim_order_copy.default"]
19+
target = [
20+
"aten._to_copy.default",
21+
"dim_order_ops._to_dim_order_copy.default",
22+
"aten.to.dtype"
23+
]
2024
sufixed_8_offset_diff = 128
2125
sufixed_16_offset_diff = 32768
2226
epsilon = 1e-6

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ class OpElementWiseMultiply:
145145
op_name: str = "ElementWiseMultiply"
146146

147147

148+
@dataclass(init=False, frozen=True)
149+
class OpElementWiseNeg:
150+
op_name: str = "ElementWiseNeg"
151+
152+
148153
@dataclass(init=False, frozen=True)
149154
class OpElementWiseNeuron:
150155
op_name: str = "ElementWiseNeuron"

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ def annotate_max_pool2d_with_indices(
396396
annotate_single_in_single_out(node, quantization_config)
397397

398398

399+
@register_annotator([torch.ops.aten.neg.default])
400+
def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None:
401+
annotate_single_in_single_out(node, quantization_config)
402+
403+
399404
@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
400405
def annotate_adaptive_avgpool2d(
401406
node: Node, quantization_config: QuantizationConfig

backends/qualcomm/tests/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,13 @@ def forward(self, x):
863863
attn_output, _ = self.multi_head_attention(x, x, x, need_weights=False)
864864
return attn_output
865865

866+
class Neg(torch.nn.Module):
867+
def __init__(self):
868+
super().__init__()
869+
870+
def forward(self, x):
871+
return torch.neg(x)
872+
866873

867874
class Pad(torch.nn.Module):
868875
def __init__(self):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,11 @@ def test_qnn_backend_minimum(self):
532532
module = Minimum() # noqa: F405
533533
sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
534534
self.lower_module_and_test_output(module, sample_input)
535+
536+
def test_qnn_backend_neg(self):
537+
module = Neg() # noqa: F405
538+
sample_input = (torch.randn(1, 4, 16, 16),)
539+
self.lower_module_and_test_output(module, sample_input)
535540

536541
def test_qnn_backend_pad(self):
537542
module = Pad() # noqa: F405

0 commit comments

Comments
 (0)