Skip to content

Commit 5a622b7

Browse files
winskuo-quicshewu-quic
authored andcommitted
Qualcomm AI Engine Direct - Support floor_divide with int input in QNN HTP backend (#14888)
### Summary - Since QNN does not support floor_divide operations for int32 or int64 inputs, it is necessary to decompose the operation into a division using floating-point precision, followed by applying the floor function. ### Test plan UT added Author: @shewu-quic cc @cccclai @shewu-quic @haowhsu-quic @DannyYuyang-quic @cbilgin --------- Co-authored-by: shewu <[email protected]> (cherry picked from commit 5af73eb)
1 parent ad1ca3f commit 5a622b7

File tree

4 files changed

+124
-21
lines changed

4 files changed

+124
-21
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .decompose_col_im import DecomposeColIm
1919
from .decompose_einsum import DecomposeEinsum
2020
from .decompose_expm1 import DecomposeExpM1
21+
from .decompose_floor_divide import DecomposeFloorDivide
2122
from .decompose_glu import DecomposeGlu
2223
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
2324
from .decompose_minmaxdim import DecomposeMinMaxDim
@@ -61,6 +62,7 @@
6162
DecomposeColIm,
6263
DecomposeEinsum,
6364
DecomposeExpM1,
65+
DecomposeFloorDivide,
6466
DecomposeGlu,
6567
DecomposeLinalgVectorNorm,
6668
DecomposeMinMaxDim,
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import merge_decomposed_graph
11+
12+
13+
class FloorDivide(torch.nn.Module):
14+
def __init__(self):
15+
super().__init__()
16+
17+
def forward(self, x, y):
18+
dtype = x.dtype
19+
result = torch.div(x, y)
20+
result = torch.floor(result)
21+
return result.to(dtype)
22+
23+
24+
class DecomposeFloorDivide(ExportPass):
25+
"""
26+
Decompose for math equivalent op.
27+
Since QNN does not support floor_divide operations for int32 or int64 inputs,
28+
it is necessary to decompose the operation into a division using floating-point precision,
29+
followed by applying the floor function.
30+
"""
31+
32+
def __init__(self) -> None:
33+
super().__init__()
34+
35+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
36+
graph = graph_module.graph
37+
for node in graph.nodes:
38+
model = FloorDivide()
39+
if (
40+
torch.ops.aten.floor_divide.default == node.target
41+
and not torch.is_floating_point(node.meta["val"])
42+
):
43+
decomposed_module = torch.export.export(
44+
model,
45+
(node.args[0].meta["val"], node.args[1].meta["val"]),
46+
strict=True,
47+
).module()
48+
with graph.inserting_before(node):
49+
# remap is used to map original node values to new node values,
50+
# which ensures that reference to nodes are correctly updated in the new graph
51+
remap = {"x": node.args[0], "y": node.args[1]}
52+
merge_decomposed_graph(
53+
remap=remap,
54+
target_node=node,
55+
target_graph=graph,
56+
decomposed_graph_module=decomposed_module,
57+
)
58+
graph.erase_node(node)
59+
60+
graph.eliminate_dead_code()
61+
graph_module.recompile()
62+
return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
DecomposeColIm,
2424
DecomposeEinsum,
2525
DecomposeExpM1,
26+
DecomposeFloorDivide,
2627
DecomposeGlu,
2728
DecomposeLinalgVectorNorm,
2829
DecomposeMinMaxDim,
@@ -223,6 +224,11 @@ def transform_for_export_pipeline(
223224
self.add_pass(DecomposeThreshold())
224225
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
225226
self.add_pass(DecomposeExpM1())
227+
# DecomposeFloorDivide does not apply to the annotation pipeline,
228+
# since the CPU QDQ model would reduce accuracy.
229+
# We keep div and floor operations in floating-point to maintain precision.
230+
# This pass is needed before to_edge pipeline to avoid mixed type for div operator with RemoveMixedTypeOperators pass.
231+
self.add_pass(DecomposeFloorDivide())
226232
self.add_pass(DecomposeWrapWithAutocast())
227233
# this pass will rewrite state_dict, it needs to be accomplished before
228234
# to_edge_transform_and_lower

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,8 @@ def test_qnn_backend_cumsum(self):
398398
for module in comb[QCOM_MODULE]:
399399
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
400400
with self.subTest(i=index):
401-
self.lower_module_and_test_output(module, sample_input)
402401
index += 1
402+
self.lower_module_and_test_output(module, sample_input)
403403

404404
def test_qnn_backend_einsum_outer_product(self):
405405
module = EinsumOuterProduct() # noqa: F405
@@ -467,8 +467,8 @@ def test_qnn_backend_element_wise_add(self):
467467
for module in comb[QCOM_MODULE]:
468468
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
469469
with self.subTest(i=index):
470-
self.lower_module_and_test_output(module, sample_input)
471470
index += 1
471+
self.lower_module_and_test_output(module, sample_input)
472472

473473
def test_qnn_backend_element_wise_and(self):
474474
module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405
@@ -506,8 +506,8 @@ def test_qnn_backend_element_wise_div(self):
506506
for module in comb[QCOM_MODULE]:
507507
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
508508
with self.subTest(i=index):
509-
self.lower_module_and_test_output(module, sample_input)
510509
index += 1
510+
self.lower_module_and_test_output(module, sample_input)
511511

512512
def test_qnn_backend_element_wise_mul(self):
513513
test_comb = [
@@ -533,8 +533,8 @@ def test_qnn_backend_element_wise_mul(self):
533533
for module in comb[QCOM_MODULE]:
534534
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
535535
with self.subTest(i=index):
536-
self.lower_module_and_test_output(module, sample_input)
537536
index += 1
537+
self.lower_module_and_test_output(module, sample_input)
538538

539539
def test_qnn_backend_element_wise_or(self):
540540
test_comb = [
@@ -608,8 +608,8 @@ def test_qnn_backend_element_wise_sub(self):
608608
for module in comb[QCOM_MODULE]:
609609
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
610610
with self.subTest(i=index):
611-
self.lower_module_and_test_output(module, sample_input)
612611
index += 1
612+
self.lower_module_and_test_output(module, sample_input)
613613

614614
@unittest.expectedFailure
615615
def test_qnn_backend_elu(self):
@@ -651,10 +651,10 @@ def test_qnn_backend_expand(self):
651651
for module in modules:
652652
for sample_input in sample_inputs:
653653
with self.subTest(i=index):
654+
index += 1
654655
self.lower_module_and_test_output(
655656
module, sample_input, passes_job=passes_job
656657
)
657-
index += 1
658658

659659
def test_qnn_backend_expm1(self):
660660
sample_input = (torch.randn(3, 4, 5),)
@@ -677,6 +677,21 @@ def test_qnn_backend_floor_divide(self):
677677
{
678678
QCOM_MODULE: [FloorDiv()], # noqa: F405
679679
QCOM_SAMPLE_INPUTS: [
680+
(torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)),
681+
(
682+
torch.randint(-100, 100, (10, 10)).float(),
683+
torch.full((10, 10), 2.5),
684+
),
685+
(torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)),
686+
(torch.tensor([10]), torch.arange(1, 5)), # Failed
687+
(torch.arange(-10, 10), torch.tensor([2])),
688+
(torch.randint(-100, 100, (20,)), torch.full((20,), 2)),
689+
(torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)),
690+
(torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)),
691+
(
692+
torch.randint(-100, 100, (2, 3, 4, 5)),
693+
torch.full((2, 3, 4, 5), 2),
694+
),
680695
(torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)),
681696
(torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])),
682697
],
@@ -692,8 +707,8 @@ def test_qnn_backend_floor_divide(self):
692707
for module in comb[QCOM_MODULE]:
693708
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
694709
with self.subTest(i=index):
695-
self.lower_module_and_test_output(module, sample_input)
696710
index += 1
711+
self.lower_module_and_test_output(module, sample_input)
697712

698713
def test_qnn_backend_fold(self):
699714
sample_input = (torch.randn(3, 512, 256),)
@@ -1136,8 +1151,8 @@ def test_qnn_backend_leaky_relu(self):
11361151
for module in comb[QCOM_MODULE]:
11371152
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
11381153
with self.subTest(i=index):
1139-
self.lower_module_and_test_output(module, sample_input)
11401154
index += 1
1155+
self.lower_module_and_test_output(module, sample_input)
11411156

11421157
def test_qnn_backend_less_equal(self):
11431158
test_comb = [
@@ -1392,8 +1407,8 @@ def test_qnn_backend_prelu(self):
13921407
for module in comb[QCOM_MODULE]:
13931408
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
13941409
with self.subTest(i=index):
1395-
self.lower_module_and_test_output(module, sample_input)
13961410
index += 1
1411+
self.lower_module_and_test_output(module, sample_input)
13971412

13981413
def test_qnn_backend_relu(self):
13991414
module = Relu() # noqa: F405
@@ -1520,8 +1535,8 @@ def test_qnn_backend_slice_scatter(self):
15201535
for module in comb[QCOM_MODULE]:
15211536
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
15221537
with self.subTest(i=index):
1523-
self.lower_module_and_test_output(module, sample_input)
15241538
index += 1
1539+
self.lower_module_and_test_output(module, sample_input)
15251540

15261541
def test_qnn_backend_stack(self):
15271542
module = Stack() # noqa: F405
@@ -2332,9 +2347,9 @@ def test_qnn_backend_element_wise_add(self):
23322347
for module in comb[QCOM_MODULE]:
23332348
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
23342349
with self.subTest(i=index):
2350+
index += 1
23352351
gm = self.get_qdq_module(module, sample_input)
23362352
self.lower_module_and_test_output(gm, sample_input)
2337-
index += 1
23382353

23392354
def test_qnn_backend_element_wise_and(self):
23402355
module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405
@@ -2373,9 +2388,9 @@ def test_qnn_backend_element_wise_div(self):
23732388
for module in comb[QCOM_MODULE]:
23742389
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
23752390
with self.subTest(i=index):
2391+
index += 1
23762392
gm = self.get_qdq_module(module, sample_input)
23772393
self.lower_module_and_test_output(gm, sample_input)
2378-
index += 1
23792394

23802395
def test_qnn_backend_element_wise_mul(self):
23812396
test_comb = [
@@ -2401,9 +2416,9 @@ def test_qnn_backend_element_wise_mul(self):
24012416
for module in comb[QCOM_MODULE]:
24022417
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
24032418
with self.subTest(i=index):
2419+
index += 1
24042420
gm = self.get_qdq_module(module, sample_input)
24052421
self.lower_module_and_test_output(gm, sample_input)
2406-
index += 1
24072422

24082423
def test_qnn_backend_element_wise_or(self):
24092424
test_comb = [
@@ -2479,9 +2494,9 @@ def test_qnn_backend_element_wise_sub(self):
24792494
for module in comb[QCOM_MODULE]:
24802495
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
24812496
with self.subTest(i=index):
2497+
index += 1
24822498
gm = self.get_qdq_module(module, sample_input)
24832499
self.lower_module_and_test_output(gm, sample_input)
2484-
index += 1
24852500

24862501
def test_qnn_backend_elu(self):
24872502
module = Elu() # noqa: F405
@@ -2530,11 +2545,11 @@ def test_qnn_backend_expand(self):
25302545
for module in modules:
25312546
for sample_input in sample_inputs:
25322547
with self.subTest(i=index):
2548+
index += 1
25332549
module = self.get_qdq_module(module, sample_input)
25342550
self.lower_module_and_test_output(
25352551
module, sample_input, passes_job=passes_job
25362552
)
2537-
index += 1
25382553

25392554
def test_qnn_backend_expm1(self):
25402555
sample_input = (torch.randn(3, 4, 5),)
@@ -2560,6 +2575,21 @@ def test_qnn_backend_floor_divide(self):
25602575
{
25612576
QCOM_MODULE: [FloorDiv()], # noqa: F405
25622577
QCOM_SAMPLE_INPUTS: [
2578+
(torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)),
2579+
(
2580+
torch.randint(-100, 100, (10, 10)).float(),
2581+
torch.full((10, 10), 2.5),
2582+
),
2583+
(torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)),
2584+
(torch.tensor([10]), torch.arange(1, 5)),
2585+
(torch.arange(-10, 10), torch.tensor([2])),
2586+
(torch.randint(-100, 100, (20,)), torch.full((20,), 2)),
2587+
(torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)),
2588+
(torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)),
2589+
(
2590+
torch.randint(-100, 100, (2, 3, 4, 5)),
2591+
torch.full((2, 3, 4, 5), 2),
2592+
),
25632593
(torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)),
25642594
(torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])),
25652595
],
@@ -2575,9 +2605,12 @@ def test_qnn_backend_floor_divide(self):
25752605
for module in comb[QCOM_MODULE]:
25762606
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
25772607
with self.subTest(i=index):
2578-
gm = self.get_qdq_module(module, sample_input)
2579-
self.lower_module_and_test_output(gm, sample_input)
25802608
index += 1
2609+
# Support int input cases with bypass_check=True
2610+
gm = self.get_qdq_module(
2611+
module, sample_input, bypass_check=True
2612+
)
2613+
self.lower_module_and_test_output(gm, sample_input)
25812614

25822615
def test_qnn_backend_fold(self):
25832616
sample_input = (torch.randn(3, 512, 256),)
@@ -3048,9 +3081,9 @@ def test_qnn_backend_leaky_relu(self):
30483081
for module in comb[QCOM_MODULE]:
30493082
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
30503083
with self.subTest(i=index):
3084+
index += 1
30513085
module = self.get_qdq_module(module, sample_input)
30523086
self.lower_module_and_test_output(module, sample_input)
3053-
index += 1
30543087

30553088
def test_qnn_backend_less_equal(self):
30563089
test_comb = [
@@ -3352,9 +3385,9 @@ def test_qnn_backend_prelu(self):
33523385
for module in comb[QCOM_MODULE]:
33533386
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
33543387
with self.subTest(i=index):
3355-
qdq_module = self.get_qdq_module(module, sample_input)
3356-
self.lower_module_and_test_output(qdq_module, sample_input)
33573388
index += 1
3389+
module = self.get_qdq_module(module, sample_input)
3390+
self.lower_module_and_test_output(module, sample_input)
33583391

33593392
def test_qnn_backend_relu(self):
33603393
module = Relu() # noqa: F405
@@ -3504,9 +3537,9 @@ def test_qnn_backend_slice_scatter(self):
35043537
for module in comb[QCOM_MODULE]:
35053538
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
35063539
with self.subTest(i=index):
3540+
index += 1
35073541
module = self.get_qdq_module(module, sample_input)
35083542
self.lower_module_and_test_output(module, sample_input)
3509-
index += 1
35103543

35113544
def test_qnn_backend_softmax(self):
35123545
modules = [Softmax(dim=1), Softmax(dim=-1)] # noqa: F405

0 commit comments

Comments
 (0)