
Commit 200dd3d

billmguo authored and facebook-github-bot committed
Add to.dtype and neg ops
Summary: add neg op and annotation of to.dtype

Differential Revision: D68815927
1 parent c5fea7e commit 200dd3d

File tree

7 files changed: +66 -1 lines changed


backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.mean.dim,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.neg.default,
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.repeat.default,
@@ -85,6 +86,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sum.dim_IntList,
         exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
+        exir_ops.edge.aten.to.dtype,
         *q_ops,
         *dq_ops,
         _operator.getitem,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,7 @@
     op_mean_dim,
     op_min,
     op_mul,
+    op_neg,
     op_pad,
     op_pow,
     op_prelu,
@@ -116,6 +117,7 @@
     op_mean_dim,
     op_min,
     op_mul,
+    op_neg,
     op_pad,
     op_pow,
     op_prelu,
backends/qualcomm/builders/op_neg.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import torch
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseNeg, QNN_OP_PACKAGE_NAME_QTI_AISW
+@register_node_visitor
+class Neg(NodeVisitor):
+    target = ["aten.neg.default"]
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        neg_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        neg_input_tensors = [neg_inp_tensor_wrapper]
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        neg_output_tensors = [output_tensor_wrapper]
+        neg_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseNeg.op_name,
+        )
+        neg_op.AddInputTensors(neg_input_tensors)
+        neg_op.AddOutputTensors(neg_output_tensors)
+        return neg_op
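
For orientation, here is a minimal sketch (not part of this commit) of where the aten.neg.default target above comes from: exporting a module that calls torch.neg and converting it to the edge dialect produces the neg node that this visitor lowers to QNN's ElementWiseNeg. The module name, input shape, and the exact printed op spelling are illustrative assumptions.

    # Sketch only: inspect the exported graph for the neg node targeted by op_neg.py.
    import torch

    from executorch.exir import to_edge


    class NegModule(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)


    ep = torch.export.export(NegModule(), (torch.randn(1, 2, 3, 4),))
    edge = to_edge(ep)
    # The printed graph should contain a single neg op (edge-dialect spelling may vary);
    # the Qualcomm partitioner hands that node to the Neg visitor above.
    print(edge.exported_program().graph_module.code)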

backends/qualcomm/builders/op_to.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 @register_node_visitor
 class To(NodeVisitor):
-    target = ["aten._to_copy.default", "dim_order_ops._to_dim_order_copy.default"]
+    target = ["aten._to_copy.default", "dim_order_ops._to_dim_order_copy.default", "aten.to.dtype"]
     sufixed_8_offset_diff = 128
     sufixed_16_offset_diff = 32768
     epsilon = 1e-6

backends/qualcomm/builders/qnn_constants.py

Lines changed: 3 additions & 0 deletions
@@ -144,6 +144,9 @@ class OpElementWiseMinimum:
 class OpElementWiseMultiply:
     op_name: str = "ElementWiseMultiply"
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseNeg:
+    op_name: str = "ElementWiseNeg"
 
 @dataclass(init=False, frozen=True)
 class OpElementWiseNeuron:

backends/qualcomm/quantizer/annotators.py

Lines changed: 3 additions & 0 deletions
@@ -395,6 +395,9 @@ def annotate_max_pool2d_with_indices(
 ) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
+@register_annotator([torch.ops.aten.neg.default])
+def annotate_log(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
 
 @register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
 def annotate_adaptive_avgpool2d(
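
Registering this annotator makes aten.neg.default a single-input/single-output quantizable op for the Qualcomm quantizer. A rough sketch of how that plays out in a PT2E flow follows; it is not part of this commit, and the default QnnQuantizer configuration and the export_for_training capture step are assumptions.

    # Sketch only: prepare/convert a graph containing neg so the new annotator
    # can attach quantization annotations (observers, then q/dq nodes) around it.
    import torch

    from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


    class NegModule(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)


    example_inputs = (torch.randn(1, 3, 8, 8),)
    graph_module = torch.export.export_for_training(NegModule(), example_inputs).module()
    prepared = prepare_pt2e(graph_module, QnnQuantizer())
    prepared(*example_inputs)            # calibration pass
    quantized = convert_pt2e(prepared)   # neg now sits between quantize/dequantize nodes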

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
@@ -863,6 +863,14 @@ def forward(self, x):
         attn_output, _ = self.multi_head_attention(x, x, x, need_weights=False)
         return attn_output
 
+class Neg(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+
+    def forward(self, x):
+        return torch.neg(x)
+
 
 class Pad(torch.nn.Module):
     def __init__(self):