Skip to content

Commit 5af73eb

Browse files
Qualcomm AI Engine Direct - Support floor_divide with int input in QNN HTP backend (pytorch#14888)
### Summary - QNN does not support floor_divide for int32 or int64 inputs, so the operation is decomposed into a floating-point division followed by a floor. ### Test plan - Unit tests (UT) added. Author: @shewu-quic cc @cccclai @shewu-quic @haowhsu-quic @DannyYuyang-quic @cbilgin --------- Co-authored-by: shewu <[email protected]>
1 parent a9fe0b4 commit 5af73eb

File tree

4 files changed

+124
-21
lines changed

4 files changed

+124
-21
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .decompose_col_im import DecomposeColIm
1919
from .decompose_einsum import DecomposeEinsum
2020
from .decompose_expm1 import DecomposeExpM1
21+
from .decompose_floor_divide import DecomposeFloorDivide
2122
from .decompose_glu import DecomposeGlu
2223
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
2324
from .decompose_minmaxdim import DecomposeMinMaxDim
@@ -61,6 +62,7 @@
6162
DecomposeColIm,
6263
DecomposeEinsum,
6364
DecomposeExpM1,
65+
DecomposeFloorDivide,
6466
DecomposeGlu,
6567
DecomposeLinalgVectorNorm,
6668
DecomposeMinMaxDim,
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import merge_decomposed_graph
11+
12+
13+
class FloorDivide(torch.nn.Module):
    """
    Reference module used as the decomposition of floor_divide:
    divide, floor the quotient, then cast back to the dtype of ``x``.
    """

    def forward(self, x, y):
        # The result is cast to x's dtype (not y's).
        # NOTE(review): for mixed integer widths this may differ from the
        # dtype floor_divide itself would produce - confirm if that matters.
        return torch.floor(torch.div(x, y)).to(x.dtype)
22+
23+
24+
class DecomposeFloorDivide(ExportPass):
    """
    Decompose non-floating-point floor_divide into math-equivalent ops.

    Since QNN does not support floor_divide operations for int32 or int64 inputs,
    it is necessary to decompose the operation into a division using floating-point precision,
    followed by applying the floor function.

    Only ``aten.floor_divide.default`` nodes whose ``meta["val"]`` is not a
    floating-point tensor are rewritten; every other node is left untouched.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace each matching floor_divide node with the exported FloorDivide graph.

        Args:
            graph_module: the graph module to transform in place.

        Returns:
            A PassResult wrapping the same graph module; the modified flag
            is always reported as True.
        """
        graph = graph_module.graph
        # FloorDivide holds no state, so a single instance serves every match.
        model = FloorDivide()

        # Walk a snapshot of the node list: the loop body inserts new nodes
        # and erases the matched one, so the live view changes underneath us.
        for node in list(graph.nodes):
            if node.target != torch.ops.aten.floor_divide.default:
                continue
            # Floating-point floor_divide is not decomposed by this pass.
            if torch.is_floating_point(node.meta["val"]):
                continue

            lhs, rhs = node.args[0], node.args[1]
            decomposed_module = torch.export.export(
                model,
                (lhs.meta["val"], rhs.meta["val"]),
                strict=True,
            ).module()

            with graph.inserting_before(node):
                # remap is used to map original node values to new node values,
                # which ensures that reference to nodes are correctly updated in the new graph
                remap = {"x": lhs, "y": rhs}
                merge_decomposed_graph(
                    remap=remap,
                    target_node=node,
                    target_graph=graph,
                    decomposed_graph_module=decomposed_module,
                )
                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
DecomposeColIm,
2424
DecomposeEinsum,
2525
DecomposeExpM1,
26+
DecomposeFloorDivide,
2627
DecomposeGlu,
2728
DecomposeLinalgVectorNorm,
2829
DecomposeMinMaxDim,
@@ -223,6 +224,11 @@ def transform_for_export_pipeline(
223224
self.add_pass(DecomposeThreshold())
224225
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
225226
self.add_pass(DecomposeExpM1())
227+
# DecomposeFloorDivide does not apply to the annotation pipeline,
228+
# since the CPU QDQ model would reduce accuracy.
229+
# We keep div and floor operations in floating-point to maintain precision.
230+
# This pass must run before the to_edge pipeline to avoid a mixed-type div operator being handled by the RemoveMixedTypeOperators pass.
231+
self.add_pass(DecomposeFloorDivide())
226232
self.add_pass(DecomposeWrapWithAutocast())
227233
# this pass will rewrite state_dict, it needs to be accomplished before
228234
# to_edge_transform_and_lower

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,8 @@ def test_qnn_backend_cumsum(self):
397397
for module in comb[QCOM_MODULE]:
398398
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
399399
with self.subTest(i=index):
400-
self.lower_module_and_test_output(module, sample_input)
401400
index += 1
401+
self.lower_module_and_test_output(module, sample_input)
402402

403403
def test_qnn_backend_einsum_outer_product(self):
404404
module = EinsumOuterProduct() # noqa: F405
@@ -466,8 +466,8 @@ def test_qnn_backend_element_wise_add(self):
466466
for module in comb[QCOM_MODULE]:
467467
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
468468
with self.subTest(i=index):
469-
self.lower_module_and_test_output(module, sample_input)
470469
index += 1
470+
self.lower_module_and_test_output(module, sample_input)
471471

472472
def test_qnn_backend_element_wise_and(self):
473473
module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405
@@ -505,8 +505,8 @@ def test_qnn_backend_element_wise_div(self):
505505
for module in comb[QCOM_MODULE]:
506506
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
507507
with self.subTest(i=index):
508-
self.lower_module_and_test_output(module, sample_input)
509508
index += 1
509+
self.lower_module_and_test_output(module, sample_input)
510510

511511
def test_qnn_backend_element_wise_mul(self):
512512
test_comb = [
@@ -532,8 +532,8 @@ def test_qnn_backend_element_wise_mul(self):
532532
for module in comb[QCOM_MODULE]:
533533
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
534534
with self.subTest(i=index):
535-
self.lower_module_and_test_output(module, sample_input)
536535
index += 1
536+
self.lower_module_and_test_output(module, sample_input)
537537

538538
def test_qnn_backend_element_wise_or(self):
539539
test_comb = [
@@ -607,8 +607,8 @@ def test_qnn_backend_element_wise_sub(self):
607607
for module in comb[QCOM_MODULE]:
608608
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
609609
with self.subTest(i=index):
610-
self.lower_module_and_test_output(module, sample_input)
611610
index += 1
611+
self.lower_module_and_test_output(module, sample_input)
612612

613613
@unittest.expectedFailure
614614
def test_qnn_backend_elu(self):
@@ -650,10 +650,10 @@ def test_qnn_backend_expand(self):
650650
for module in modules:
651651
for sample_input in sample_inputs:
652652
with self.subTest(i=index):
653+
index += 1
653654
self.lower_module_and_test_output(
654655
module, sample_input, passes_job=passes_job
655656
)
656-
index += 1
657657

658658
def test_qnn_backend_expm1(self):
659659
sample_input = (torch.randn(3, 4, 5),)
@@ -676,6 +676,21 @@ def test_qnn_backend_floor_divide(self):
676676
{
677677
QCOM_MODULE: [FloorDiv()], # noqa: F405
678678
QCOM_SAMPLE_INPUTS: [
679+
(torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)),
680+
(
681+
torch.randint(-100, 100, (10, 10)).float(),
682+
torch.full((10, 10), 2.5),
683+
),
684+
(torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)),
685+
(torch.tensor([10]), torch.arange(1, 5)), # Failed
686+
(torch.arange(-10, 10), torch.tensor([2])),
687+
(torch.randint(-100, 100, (20,)), torch.full((20,), 2)),
688+
(torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)),
689+
(torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)),
690+
(
691+
torch.randint(-100, 100, (2, 3, 4, 5)),
692+
torch.full((2, 3, 4, 5), 2),
693+
),
679694
(torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)),
680695
(torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])),
681696
],
@@ -691,8 +706,8 @@ def test_qnn_backend_floor_divide(self):
691706
for module in comb[QCOM_MODULE]:
692707
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
693708
with self.subTest(i=index):
694-
self.lower_module_and_test_output(module, sample_input)
695709
index += 1
710+
self.lower_module_and_test_output(module, sample_input)
696711

697712
def test_qnn_backend_fold(self):
698713
sample_input = (torch.randn(3, 512, 256),)
@@ -972,8 +987,8 @@ def test_qnn_backend_leaky_relu(self):
972987
for module in comb[QCOM_MODULE]:
973988
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
974989
with self.subTest(i=index):
975-
self.lower_module_and_test_output(module, sample_input)
976990
index += 1
991+
self.lower_module_and_test_output(module, sample_input)
977992

978993
def test_qnn_backend_less_equal(self):
979994
test_comb = [
@@ -1228,8 +1243,8 @@ def test_qnn_backend_prelu(self):
12281243
for module in comb[QCOM_MODULE]:
12291244
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
12301245
with self.subTest(i=index):
1231-
self.lower_module_and_test_output(module, sample_input)
12321246
index += 1
1247+
self.lower_module_and_test_output(module, sample_input)
12331248

12341249
def test_qnn_backend_relu(self):
12351250
module = Relu() # noqa: F405
@@ -1356,8 +1371,8 @@ def test_qnn_backend_slice_scatter(self):
13561371
for module in comb[QCOM_MODULE]:
13571372
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
13581373
with self.subTest(i=index):
1359-
self.lower_module_and_test_output(module, sample_input)
13601374
index += 1
1375+
self.lower_module_and_test_output(module, sample_input)
13611376

13621377
def test_qnn_backend_stack(self):
13631378
module = Stack() # noqa: F405
@@ -2168,9 +2183,9 @@ def test_qnn_backend_element_wise_add(self):
21682183
for module in comb[QCOM_MODULE]:
21692184
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
21702185
with self.subTest(i=index):
2186+
index += 1
21712187
gm = self.get_qdq_module(module, sample_input)
21722188
self.lower_module_and_test_output(gm, sample_input)
2173-
index += 1
21742189

21752190
def test_qnn_backend_element_wise_and(self):
21762191
module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405
@@ -2209,9 +2224,9 @@ def test_qnn_backend_element_wise_div(self):
22092224
for module in comb[QCOM_MODULE]:
22102225
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
22112226
with self.subTest(i=index):
2227+
index += 1
22122228
gm = self.get_qdq_module(module, sample_input)
22132229
self.lower_module_and_test_output(gm, sample_input)
2214-
index += 1
22152230

22162231
def test_qnn_backend_element_wise_mul(self):
22172232
test_comb = [
@@ -2237,9 +2252,9 @@ def test_qnn_backend_element_wise_mul(self):
22372252
for module in comb[QCOM_MODULE]:
22382253
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
22392254
with self.subTest(i=index):
2255+
index += 1
22402256
gm = self.get_qdq_module(module, sample_input)
22412257
self.lower_module_and_test_output(gm, sample_input)
2242-
index += 1
22432258

22442259
def test_qnn_backend_element_wise_or(self):
22452260
test_comb = [
@@ -2315,9 +2330,9 @@ def test_qnn_backend_element_wise_sub(self):
23152330
for module in comb[QCOM_MODULE]:
23162331
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
23172332
with self.subTest(i=index):
2333+
index += 1
23182334
gm = self.get_qdq_module(module, sample_input)
23192335
self.lower_module_and_test_output(gm, sample_input)
2320-
index += 1
23212336

23222337
def test_qnn_backend_elu(self):
23232338
module = Elu() # noqa: F405
@@ -2366,11 +2381,11 @@ def test_qnn_backend_expand(self):
23662381
for module in modules:
23672382
for sample_input in sample_inputs:
23682383
with self.subTest(i=index):
2384+
index += 1
23692385
module = self.get_qdq_module(module, sample_input)
23702386
self.lower_module_and_test_output(
23712387
module, sample_input, passes_job=passes_job
23722388
)
2373-
index += 1
23742389

23752390
def test_qnn_backend_expm1(self):
23762391
sample_input = (torch.randn(3, 4, 5),)
@@ -2396,6 +2411,21 @@ def test_qnn_backend_floor_divide(self):
23962411
{
23972412
QCOM_MODULE: [FloorDiv()], # noqa: F405
23982413
QCOM_SAMPLE_INPUTS: [
2414+
(torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)),
2415+
(
2416+
torch.randint(-100, 100, (10, 10)).float(),
2417+
torch.full((10, 10), 2.5),
2418+
),
2419+
(torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)),
2420+
(torch.tensor([10]), torch.arange(1, 5)),
2421+
(torch.arange(-10, 10), torch.tensor([2])),
2422+
(torch.randint(-100, 100, (20,)), torch.full((20,), 2)),
2423+
(torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)),
2424+
(torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)),
2425+
(
2426+
torch.randint(-100, 100, (2, 3, 4, 5)),
2427+
torch.full((2, 3, 4, 5), 2),
2428+
),
23992429
(torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)),
24002430
(torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])),
24012431
],
@@ -2411,9 +2441,12 @@ def test_qnn_backend_floor_divide(self):
24112441
for module in comb[QCOM_MODULE]:
24122442
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
24132443
with self.subTest(i=index):
2414-
gm = self.get_qdq_module(module, sample_input)
2415-
self.lower_module_and_test_output(gm, sample_input)
24162444
index += 1
2445+
# Support int input cases with bypass_check=True
2446+
gm = self.get_qdq_module(
2447+
module, sample_input, bypass_check=True
2448+
)
2449+
self.lower_module_and_test_output(gm, sample_input)
24172450

24182451
def test_qnn_backend_fold(self):
24192452
sample_input = (torch.randn(3, 512, 256),)
@@ -2719,9 +2752,9 @@ def test_qnn_backend_leaky_relu(self):
27192752
for module in comb[QCOM_MODULE]:
27202753
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
27212754
with self.subTest(i=index):
2755+
index += 1
27222756
module = self.get_qdq_module(module, sample_input)
27232757
self.lower_module_and_test_output(module, sample_input)
2724-
index += 1
27252758

27262759
def test_qnn_backend_less_equal(self):
27272760
test_comb = [
@@ -3023,9 +3056,9 @@ def test_qnn_backend_prelu(self):
30233056
for module in comb[QCOM_MODULE]:
30243057
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
30253058
with self.subTest(i=index):
3026-
qdq_module = self.get_qdq_module(module, sample_input)
3027-
self.lower_module_and_test_output(qdq_module, sample_input)
30283059
index += 1
3060+
module = self.get_qdq_module(module, sample_input)
3061+
self.lower_module_and_test_output(module, sample_input)
30293062

30303063
def test_qnn_backend_relu(self):
30313064
module = Relu() # noqa: F405
@@ -3175,9 +3208,9 @@ def test_qnn_backend_slice_scatter(self):
31753208
for module in comb[QCOM_MODULE]:
31763209
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
31773210
with self.subTest(i=index):
3211+
index += 1
31783212
module = self.get_qdq_module(module, sample_input)
31793213
self.lower_module_and_test_output(module, sample_input)
3180-
index += 1
31813214

31823215
def test_qnn_backend_softmax(self):
31833216
modules = [Softmax(dim=1), Softmax(dim=-1)] # noqa: F405

0 commit comments

Comments
 (0)