Skip to content

Commit 1a4c77c

Browse files
committed
Qualcomm AI Engine Direct - GA FocalNet
1 parent 4e38f4a commit 1a4c77c

19 files changed

+284
-39
lines changed

backends/qualcomm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
253253

254254
pybind11_extension(PyQnnManagerAdaptor)
255255
pybind11_extension(PyQnnWrapperAdaptor)
256-
if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug|RelWithDebInfo)
256+
if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES RelWithDebInfo)
257257
# Strip unnecessary sections of the binary
258258
pybind11_strip(PyQnnManagerAdaptor)
259259
pybind11_strip(PyQnnWrapperAdaptor)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
78
from .annotate_quant_attrs import AnnotateQuantAttrs
89
from .annotate_stack import AnnotateStack
910
from .annotate_unbind import AnnotateUnbind
@@ -38,6 +39,7 @@
3839

3940

4041
__all__ = [
42+
AnnotateAdaptiveAvgPool1D,
4143
AnnotateQuantAttrs,
4244
AnnotateStack,
4345
AnnotateUnbind,
backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
8+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
11+
12+
from .utils import get_quant_attrs
13+
14+
15+
class AnnotateAdaptiveAvgPool1D(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after quantization process.

    adaptive_avg_pool1d got decomposed to
    unsqueeze -> adaptive_avg_pool2d -> squeeze, so every node of the matched
    source partition is tagged with its own copy of the attributes obtained
    from the quantize node that consumes the partition output.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
        """Tag each adaptive_avg_pool1d source partition followed by a quantize op."""
        partitions = get_source_partitions(
            graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default]
        )
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # A partition without output nodes has nothing to annotate.
                if not src_partition.output_nodes:
                    continue
                output = src_partition.output_nodes[0]

                # Look at the first consumer only, without building a list;
                # an output with no users is skipped instead of raising.
                first_user = next(iter(output.users), None)
                if first_user is None or first_user.target not in q_ops:
                    continue

                quant_attrs = get_quant_attrs(self.edge_program, first_user)
                for n in src_partition.nodes:
                    # Each node gets an independent copy of the attributes.
                    n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def call(self, graph_module: torch.fx.GraphModule):
        """Annotate the graph, recompile it and report it as modified."""
        self._annotate_adaptive_avg_pool1d(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88

99
import torch
10+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
1011
from executorch.backends.qualcomm.builders.utils import get_parameter
1112
from executorch.backends.qualcomm.utils.constants import (
1213
QCOM_DTYPE,
@@ -20,7 +21,7 @@
2021
)
2122
from executorch.exir.pass_base import ExportPass, PassResult
2223

23-
from .utils import dq_ops, get_quant_attrs, q_ops
24+
from .utils import get_quant_attrs
2425

2526

2627
class AnnotateQuantAttrs(ExportPass):

backends/qualcomm/_passes/annotate_stack.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import get_quant_attrs, q_ops
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateStack(ExportPass):
@@ -27,7 +28,7 @@ def _annotate_stack(self, graph_module: torch.fx.GraphModule):
2728
partitions = get_source_partitions(
2829
graph_module.graph, [torch.stack, torch.ops.aten.stack.default, "stack"]
2930
)
30-
for _, src_partitions in partitions.items():
31+
for src_partitions in partitions.values():
3132
for src_partition in src_partitions:
3233
output = src_partition.output_nodes[0]
3334
if (list(output.users)[0].target) in q_ops:

backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
79
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
810
from executorch.exir.pass_base import ExportPass, PassResult
911
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1012

11-
from .utils import dq_ops, get_quant_attrs
13+
from .utils import get_quant_attrs
1214

1315

1416
class AnnotateUnbind(ExportPass):
@@ -27,7 +29,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
2729
partitions = get_source_partitions(
2830
graph_module.graph, [torch.unbind, torch.ops.aten.unbind.int, "unbind"]
2931
)
30-
for _, src_partitions in partitions.items():
32+
for src_partitions in partitions.values():
3133
for src_partition in src_partitions:
3234
if src_partition.input_nodes[0].target in dq_ops:
3335
q_node = src_partition.input_nodes[0].args[0]

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
9+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass, PassResult
1012
from executorch.exir.passes import dead_code_elimination_pass
1113

12-
from .utils import dq_ops
13-
1414

1515
class ExpandBroadcastTensorShape(ExportPass):
1616
"""

backends/qualcomm/_passes/fold_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
79
from executorch.backends.qualcomm.builders.utils import is_parameter
810
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
911
from executorch.exir.dialects._ops import ops as exir_ops
1012
from executorch.exir.pass_base import ExportPass, PassResult
1113
from executorch.exir.passes import dead_code_elimination_pass
1214

13-
from .utils import dq_ops, q_ops
14-
1515

1616
class FoldQDQ(ExportPass):
1717
"""

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import torch
99

10+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
11+
1012
from executorch.backends.qualcomm.builders.utils import is_parameter
1113
from executorch.backends.qualcomm.utils.constants import (
1214
QCOM_ENCODING,
@@ -16,8 +18,6 @@
1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from executorch.exir.pass_base import ExportPass, PassResult
1820

19-
from .utils import q_ops
20-
2121

2222
class InsertIOQDQ(ExportPass):
2323
"""

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict
1010

1111
from executorch.backends.qualcomm._passes import (
12+
AnnotateAdaptiveAvgPool1D,
1213
AnnotateQuantAttrs,
1314
AnnotateStack,
1415
AnnotateUnbind,
@@ -73,6 +74,7 @@ def get_capture_program_passes():
7374
# The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
7475
# If a pass is activated, it will be executed by default.
7576
default_passes_and_setting = [
77+
(AnnotateAdaptiveAvgPool1D, True),
7678
(AnnotateQuantAttrs, True),
7779
(AnnotateStack, True),
7880
(AnnotateUnbind, True),
@@ -128,11 +130,11 @@ def get_to_edge_transform_passes(
128130
dep_table: Dict = None,
129131
):
130132
# TODO: remove this workaround when target could be correctly detected
131-
from executorch.backends.qualcomm._passes import utils
133+
from executorch.backends.qualcomm.builders import node_visitor
132134
from executorch.exir.dialects._ops import ops as exir_ops
133135

134-
utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
135-
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
136+
node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
137+
node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
136138

137139
passes_job = (
138140
passes_job if passes_job is not None else get_capture_program_passes()

0 commit comments

Comments
 (0)