Skip to content

Commit d9f3f62

Browse files
committed
Support Binary Alpha Operator
1 parent 5fd66ee commit d9f3f62

File tree

5 files changed

+172
-0
lines changed

5 files changed

+172
-0
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .convert_linear_to_conv2d import ConvertLinearToConv2d
1414
from .convert_square_to_pow import ConvertSquareToPow
1515
from .decompose_any import DecomposeAny
16+
from .decompose_binary_alpha import DecomposeBinaryAlpha
1617
from .decompose_cdist import DecomposeCDist
1718
from .decompose_col_im import DecomposeColIm
1819
from .decompose_einsum import DecomposeEinsum
@@ -54,6 +55,7 @@
5455
ConvertLinearToConv2d,
5556
ConvertSquareToPow,
5657
DecomposeAny,
58+
DecomposeBinaryAlpha,
5759
DecomposeCDist,
5860
DecomposeColIm,
5961
DecomposeEinsum,
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import copy_meta
11+
12+
decomp_set = {torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor}
13+
14+
15+
class DecomposeBinaryAlpha(ExportPass):
    """
    QNN does not support alpha parameter for add/sub.
    Decompose to mul + add / mul + sub

    For every node whose target is in ``decomp_set`` and that carries an
    ``alpha`` keyword argument, the node is rewritten from
    ``op(a, b, alpha=alpha)`` to ``op(a, b * alpha)``:

    * ``alpha == 1``: the keyword is simply dropped, no multiply is needed.
    * ``b`` is a Python number: alpha is folded into the constant, because
      ``aten.mul.Scalar`` needs a tensor as its first argument.
    * otherwise: an ``aten.mul.Scalar(b, alpha)`` node is inserted right
      before the add/sub node and used as its second operand.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite all add/sub nodes of ``graph_module`` that use ``alpha``.

        Args:
            graph_module: the FX graph module to transform in place.

        Returns:
            PassResult wrapping the same ``graph_module``, flagged as modified.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in decomp_set or "alpha" not in node.kwargs:
                continue

            alpha = node.kwargs["alpha"]
            # Remove alpha from immutable dict
            node.kwargs = {k: v for k, v in node.kwargs.items() if k != "alpha"}

            # x + 1 * y == x + y, so dropping the kwarg is already enough.
            if alpha == 1:
                continue

            input2_node = node.args[1]

            # A constant second operand cannot feed mul.Scalar (its first
            # argument must be a tensor); fold alpha into the constant instead.
            if isinstance(input2_node, (int, float)):
                node.args = (node.args[0], input2_node * alpha)
                continue

            with graph.inserting_before(node):
                mul_node = graph.create_node(
                    "call_function",
                    torch.ops.aten.mul.Scalar,
                    (input2_node, alpha),
                )
            # NOTE(review): meta is taken from the add/sub node, not from its
            # second operand; if the operands broadcast, the copied metadata
            # may not describe the mul output exactly - confirm.
            mul_node.meta = copy_meta(node.meta)
            # Assigning args updates use/user bookkeeping on both nodes, so
            # mul_node.users must not be written by hand.
            node.args = (node.args[0], mul_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ConvertLinearToConv2d,
1919
ConvertSquareToPow,
2020
DecomposeAny,
21+
DecomposeBinaryAlpha,
2122
DecomposeCDist,
2223
DecomposeColIm,
2324
DecomposeEinsum,
@@ -194,6 +195,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
194195
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
195196
self.add_pass(RecomposeRmsNorm(quantization_capture=True))
196197
self.add_pass(ReplaceArangeArgs())
198+
self.add_pass(DecomposeBinaryAlpha())
197199
self.add_pass(DecomposeCDist())
198200
self.add_pass(DecomposeScaledDotProductAttention())
199201
self.add_pass(DecomposeRoll())
@@ -210,6 +212,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
210212
def transform_for_export_pipeline(
211213
self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False
212214
):
215+
self.add_pass(DecomposeBinaryAlpha())
213216
self.add_pass(DecomposeCDist())
214217
self.add_pass(DecomposeScaledDotProductAttention())
215218
self.add_pass(DecomposeRoll())

backends/qualcomm/tests/models.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ def forward(self, x, y):
6666
return torch.add(x, y)
6767

6868

69+
class AddAlpha(torch.nn.Module):
    """Adds two tensors through ``torch.add`` using its ``alpha`` keyword."""

    def __init__(self, alpha):
        super().__init__()
        # Multiplier torch.add applies to the second operand.
        self.alpha = alpha

    def forward(self, x, y):
        scaled_sum = torch.add(x, y, alpha=self.alpha)
        return scaled_sum
76+
77+
78+
class AddAlphaConstant(torch.nn.Module):
    """Adds the constant 5.0 and a tensor via ``torch.add`` with ``alpha``.

    ``constant_first`` selects whether the constant is the first or the
    second operand of ``torch.add``.
    """

    def __init__(self, alpha, constant_first=False):
        super().__init__()
        self.alpha = alpha
        self.constant_first = constant_first

    def forward(self, x):
        # Pick the operand order once, then issue a single torch.add call.
        operands = (5.0, x) if self.constant_first else (x, 5.0)
        return torch.add(*operands, alpha=self.alpha)
89+
90+
6991
class AddConstantFloat(torch.nn.Module):
7092
def __init__(self):
7193
super().__init__()
@@ -1863,6 +1885,28 @@ def forward(self, x, y):
18631885
return torch.sub(x, y)
18641886

18651887

1888+
class SubAlpha(torch.nn.Module):
    """Subtracts two tensors through ``torch.sub`` using its ``alpha`` keyword."""

    def __init__(self, alpha):
        super().__init__()
        # Multiplier torch.sub applies to the second operand.
        self.alpha = alpha

    def forward(self, x, y):
        scaled_diff = torch.sub(x, y, alpha=self.alpha)
        return scaled_diff
1895+
1896+
1897+
class SubAlphaConstant(torch.nn.Module):
    """Subtracts the constant 5.0 and a tensor via ``torch.sub`` with ``alpha``.

    ``constant_first`` selects whether the constant is the first or the
    second operand of ``torch.sub``.
    """

    def __init__(self, alpha, constant_first=False):
        super().__init__()
        self.alpha = alpha
        self.constant_first = constant_first

    def forward(self, x):
        # Pick the operand order once, then issue a single torch.sub call.
        operands = (5.0, x) if self.constant_first else (x, 5.0)
        return torch.sub(*operands, alpha=self.alpha)
1908+
1909+
18661910
class SubConstantFloat(torch.nn.Module):
18671911
def __init__(self):
18681912
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,24 @@ def test_qnn_backend_element_wise_add(self):
372372
],
373373
QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)],
374374
},
375+
{
376+
QCOM_MODULE: [
377+
AddAlpha(alpha=2), # noqa: F405
378+
],
379+
QCOM_SAMPLE_INPUTS: [
380+
(
381+
torch.tensor([[1.2, 1.3, 1.4]]),
382+
torch.tensor([[0.8, 1.6, 0.2]]),
383+
)
384+
],
385+
},
386+
{
387+
QCOM_MODULE: [
388+
AddAlphaConstant(alpha=2, constant_first=True), # noqa: F405
389+
AddAlphaConstant(alpha=2, constant_first=False), # noqa: F405
390+
],
391+
QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)],
392+
},
375393
]
376394

377395
index = 0
@@ -495,6 +513,24 @@ def test_qnn_backend_element_wise_sub(self):
495513
QCOM_MODULE: [SubConstantFloat()], # noqa: F405
496514
QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
497515
},
516+
{
517+
QCOM_MODULE: [
518+
SubAlpha(alpha=2), # noqa: F405
519+
],
520+
QCOM_SAMPLE_INPUTS: [
521+
(
522+
torch.tensor([[1.2, 1.3, 1.4]]),
523+
torch.tensor([[0.8, 1.6, 0.2]]),
524+
)
525+
],
526+
},
527+
{
528+
QCOM_MODULE: [
529+
SubAlphaConstant(alpha=2, constant_first=True), # noqa: F405
530+
SubAlphaConstant(alpha=2, constant_first=False), # noqa: F405
531+
],
532+
QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)],
533+
},
498534
]
499535

500536
index = 0
@@ -1880,6 +1916,24 @@ def test_qnn_backend_element_wise_add(self):
18801916
QCOM_MODULE: [AddConstantFloat(), AddConstantLong()], # noqa: F405
18811917
QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
18821918
},
1919+
{
1920+
QCOM_MODULE: [
1921+
AddAlpha(alpha=2), # noqa: F405
1922+
],
1923+
QCOM_SAMPLE_INPUTS: [
1924+
(
1925+
torch.tensor([[1.2, 1.3, 1.4]]),
1926+
torch.tensor([[0.8, 1.6, 0.2]]),
1927+
)
1928+
],
1929+
},
1930+
{
1931+
QCOM_MODULE: [
1932+
AddAlphaConstant(alpha=2, constant_first=True), # noqa: F405
1933+
AddAlphaConstant(alpha=2, constant_first=False), # noqa: F405
1934+
],
1935+
QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)],
1936+
},
18831937
]
18841938

18851939
index = 0
@@ -2009,6 +2063,24 @@ def test_qnn_backend_element_wise_sub(self):
20092063
QCOM_MODULE: [SubConstantFloat(), SubConstantLong()], # noqa: F405
20102064
QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
20112065
},
2066+
{
2067+
QCOM_MODULE: [
2068+
SubAlpha(alpha=2), # noqa: F405
2069+
],
2070+
QCOM_SAMPLE_INPUTS: [
2071+
(
2072+
torch.tensor([[1.2, 1.3, 1.4]]),
2073+
torch.tensor([[0.8, 1.6, 0.2]]),
2074+
)
2075+
],
2076+
},
2077+
{
2078+
QCOM_MODULE: [
2079+
SubAlphaConstant(alpha=2, constant_first=True), # noqa: F405
2080+
SubAlphaConstant(alpha=2, constant_first=False), # noqa: F405
2081+
],
2082+
QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)],
2083+
},
20122084
]
20132085

20142086
index = 0

0 commit comments

Comments
 (0)