Skip to content

Commit a825fd1

Browse files
Qualcomm AI Engine Direct - Support tile op for different I/O rank
Summary: - Support if the rank of input tensor is less than the rank of output tensor. - make_quantizer kwargs alignment. - Remove module.eval() since calling eval() is not supported for exported models.
1 parent c9c5481 commit a825fd1

File tree

5 files changed

+45
-16
lines changed

5 files changed

+45
-16
lines changed

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@ def __init__(self):
2222
exir_ops.edge.aten.sub.Tensor,
2323
exir_ops.edge.aten.mul.Tensor,
2424
exir_ops.edge.aten.div.Tensor,
25+
# Support if the rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.
26+
exir_ops.edge.aten.expand_copy.default,
2527
]
2628

2729
def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
2830
for node in graph_module.graph.nodes:
2931
if node.target in self.broadcast_op_targets:
3032
for arg in node.args:
33+
if not isinstance(arg, torch.fx.Node):
34+
continue
3135
input_rank = len(arg.meta["val"].shape)
3236
output_rank = len(node.meta["val"].shape)
3337
if input_rank != output_rank:

backends/qualcomm/_passes/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def get_passes_dependency_for_capture_program():
104104
ConvertConv1dToConv2d: [FoldQDQ],
105105
DecomposeAny: [RemoveRedundancy],
106106
DecomposeLinalgVectorNorm: [RemoveRedundancy],
107-
ExpandBroadcastTensorShape: [RemoveRedundancy],
107+
ExpandBroadcastTensorShape: [FoldQDQ],
108108
FixedLinearKeepDim: [FoldQDQ],
109109
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
110110
I64toI32: [RemoveRedundancy],

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@
6969
from collections import defaultdict
7070
from typing import List
7171

72-
from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO
72+
from executorch.backends.qualcomm._passes import (
73+
ExpandBroadcastTensorShape,
74+
FoldQDQ,
75+
TagQuantIO,
76+
)
7377
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
7478
from executorch.backends.qualcomm.debugger.utils import DrawGraph
7579
from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model
@@ -430,10 +434,20 @@ def test_qnn_backend_equal(self):
430434

431435
def test_qnn_backend_expand(self):
432436
modules = [ExpandAs(), ExpandCopy()] # noqa: F405
433-
sample_input = (torch.randn([3, 1]),)
434-
for i, module in enumerate(modules):
435-
with self.subTest(i=i):
436-
self.lower_module_and_test_output(module, sample_input)
437+
sample_inputs = [
438+
(torch.randn([3, 1]),),
439+
(torch.randn([4]),),
440+
]
441+
passes_job = get_capture_program_passes()
442+
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
443+
index = 0
444+
for module in modules:
445+
for sample_input in sample_inputs:
446+
with self.subTest(i=index):
447+
self.lower_module_and_test_output(
448+
module, sample_input, passes_job=passes_job
449+
)
450+
index += 1
437451

438452
def test_qnn_backend_expm1(self):
439453
sample_input = (torch.randn(3, 4, 5),)
@@ -1506,11 +1520,21 @@ def test_qnn_backend_equal(self):
15061520

15071521
def test_qnn_backend_expand(self):
15081522
modules = [ExpandAs(), ExpandCopy()] # noqa: F405
1509-
sample_input = (torch.randn([3, 1]),)
1510-
for i, module in enumerate(modules):
1511-
with self.subTest(i=i):
1512-
module = self.get_qdq_module(module, sample_input)
1513-
self.lower_module_and_test_output(module, sample_input)
1523+
sample_inputs = [
1524+
(torch.randn([3, 1]),),
1525+
(torch.randn([4]),),
1526+
]
1527+
passes_job = get_capture_program_passes()
1528+
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
1529+
index = 0
1530+
for module in modules:
1531+
for sample_input in sample_inputs:
1532+
with self.subTest(i=index):
1533+
module = self.get_qdq_module(module, sample_input)
1534+
self.lower_module_and_test_output(
1535+
module, sample_input, passes_job=passes_job
1536+
)
1537+
index += 1
15141538

15151539
def test_qnn_backend_expm1(self):
15161540
sample_input = (torch.randn(3, 4, 5),)

backends/qualcomm/tests/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import subprocess
1010
import tempfile
1111
import unittest
12-
from typing import Callable, Dict, List, Optional, Tuple
12+
from typing import Callable, Dict, List, Optional, OrderedDict, Tuple
1313

1414
import numpy as np
1515
import torch
@@ -435,6 +435,7 @@ def lower_module_and_test_output(
435435
expected_profile_events: int = -1,
436436
expected_intermediate_events: int = -1,
437437
assert_output_equal: bool = True,
438+
passes_job: Optional[OrderedDict] = None,
438439
skip_node_id_set: set = None,
439440
skip_node_op_set: set = None,
440441
dynamic_shapes: Dict = None,
@@ -444,6 +445,7 @@ def lower_module_and_test_output(
444445
sample_inputs,
445446
self.compiler_specs,
446447
dynamic_shapes=dynamic_shapes,
448+
passes_job=passes_job,
447449
skip_node_id_set=skip_node_id_set,
448450
skip_node_op_set=skip_node_op_set,
449451
)
@@ -506,7 +508,6 @@ def get_qdq_module(
506508
block_size_map: Dict[str, Tuple] = None,
507509
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
508510
) -> torch.fx.GraphModule:
509-
module = module.eval()
510511
m = torch.export.export(
511512
module, inputs, dynamic_shapes=dynamic_shapes, strict=True
512513
).module()

examples/qualcomm/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def make_quantizer(
262262
per_channel_linear=False,
263263
act_observer=MovingAverageMinMaxObserver,
264264
is_qat=False,
265-
callback_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
265+
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
266266
):
267267
quantizer = QnnQuantizer()
268268
quantizer.add_custom_quant_annotations(custom_annotations)
@@ -273,8 +273,8 @@ def make_quantizer(
273273
is_linear_per_channel=per_channel_linear,
274274
act_observer=act_observer,
275275
)
276-
callback_qconfig_list = callback_qconfig_list or []
277-
quantizer.set_submodule_qconfig_list(callback_qconfig_list)
276+
submodule_qconfig_list = submodule_qconfig_list or []
277+
quantizer.set_submodule_qconfig_list(submodule_qconfig_list)
278278
return quantizer
279279

280280

0 commit comments

Comments
 (0)