Skip to content

Commit 3f188ff

Browse files
committed
delegated copy op
1 parent 30b31ac commit 3f188ff

File tree

4 files changed

+94
-2
lines changed

4 files changed

+94
-2
lines changed

backends/qualcomm/builders/__init__.py

100644100755
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
op_ceil,
1515
op_clamp,
1616
op_conv2d,
17+
op_copy,
1718
op_depth_to_space,
1819
op_dequantize,
1920
op_div,
@@ -70,6 +71,7 @@
7071
op_ceil,
7172
op_clamp,
7273
op_conv2d,
74+
op_copy,
7375
op_depth_to_space,
7476
op_dequantize,
7577
op_div,
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import torch
11+
from executorch.backends.qualcomm.utils.constants import (
12+
QCOM_QUANT_ATTRS,
13+
QCOM_SCALE,
14+
QCOM_ZERO_POINT,
15+
)
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
18+
from .node_visitor import NodeVisitor, register_node_visitor
19+
from .qnn_constants import OpElementWiseAdd, QNN_OP_PACKAGE_NAME_QTI_AISW
20+
21+
22+
@register_node_visitor
23+
class Copy(NodeVisitor):
24+
target = ["aten.copy.default"]
25+
26+
def __init__(self, *args) -> None:
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: torch.fx.Node,
32+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
33+
) -> PyQnnWrapper.PyQnnOpWrapper:
34+
input_node = node.args[0]
35+
input_tensor = self.get_tensor(input_node, node)
36+
copy_inp_tensor_wrapper = self.define_tensor(
37+
input_node,
38+
input_tensor,
39+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
40+
nodes_to_wrappers,
41+
is_input_tensor=True,
42+
)
43+
# 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
44+
zero_input_node = torch.fx.Node(
45+
node.graph,
46+
node.name + "_runtime_scalar",
47+
"call_function",
48+
exir_ops.edge.aten.scalar_tensor.default,
49+
(), # args
50+
{}, # kwargs
51+
)
52+
zero_input_tensor = torch.tensor(0, dtype=input_tensor.dtype)
53+
if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS):
54+
quant_attrs = quant_attrs.copy()
55+
quant_attrs[QCOM_ZERO_POINT] = 0
56+
quant_attrs[QCOM_SCALE] = 1
57+
zero_input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
58+
59+
60+
zero_tensor_wrapper = self.define_tensor(
61+
zero_input_node,
62+
zero_input_tensor,
63+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
64+
nodes_to_wrappers,
65+
is_input_tensor=True,
66+
)
67+
copy_input_tensors = [copy_inp_tensor_wrapper, zero_tensor_wrapper]
68+
69+
if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS):
70+
quant_attrs = quant_attrs.copy()
71+
# Because there is no output after convert_pt2e, the QCOM_QUANT_ATTRS of node is none
72+
node.meta[QCOM_QUANT_ATTRS] = quant_attrs
73+
output_tensor = self.get_tensor(node, node)
74+
output_tensor_wrapper = self.define_tensor(
75+
node,
76+
output_tensor,
77+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
78+
nodes_to_wrappers,
79+
is_input_tensor=False,
80+
)
81+
copy_output_tensors = [output_tensor_wrapper]
82+
83+
copy_op = PyQnnWrapper.PyQnnOpWrapper(
84+
node.name,
85+
QNN_OP_PACKAGE_NAME_QTI_AISW,
86+
OpElementWiseAdd.op_name,
87+
)
88+
copy_op.AddInputTensors(copy_input_tensors)
89+
copy_op.AddOutputTensors(copy_output_tensors)
90+
91+
return copy_op

backends/qualcomm/partition/common_defs.py

100644100755
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
exir_ops.edge.aten.clone.default,
1414
exir_ops.edge.aten.full.default,
1515
exir_ops.edge.aten.slice_scatter.default,
16-
exir_ops.edge.aten.copy.default,
1716
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
1817
]
1918

backends/qualcomm/quantizer/custom_annotation.py

100644100755
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def get_custom_quant_ios_dtype(
161161

162162
# Tag index put node before copy node, because copy is a skipped node in qnn
163163
if (
164-
exir_ops.edge.aten.index_put.default == node.target
164+
exir_ops.edge.aten.copy.default == node.target
165165
and node.meta["val"].shape == cache_shape
166166
):
167167
return kv_dtype

0 commit comments

Comments
 (0)