Skip to content

Commit 9314b68

Browse files
committed
Qualcomm AI Engine Direct - Enable custom operator
Summary: - Support to register op package in QNN Backend - Add example script to run torch custom op with QNN Op package - Allow op package override torch built-in operator - Add op package example - Modify the flag of dlopen for QNN library - Generate custom op based on the meta and _schema.arguments of torch.fx.Node - Add README for the custom op
1 parent 64c4b33 commit 9314b68

File tree

112 files changed

+2079
-156
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+2079
-156
lines changed

backends/qualcomm/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ add_library(qnn_implementation STATIC)
130130
add_library(qnn_logger STATIC)
131131
add_library(qnn_manager STATIC)
132132
add_library(qnn_mem_manager STATIC)
133+
add_library(qnn_op_package_manager STATIC)
133134
add_library(qnn_profiler STATIC)
134135
add_library(qnn_schema INTERFACE ${_qnn_schema__outputs})
135136
add_library(qnn_sys_function_interface INTERFACE)
@@ -152,13 +153,13 @@ target_link_libraries(
152153
target_link_libraries(qnn_executorch_logging PRIVATE qnn_schema)
153154
target_link_libraries(qnn_profiler PRIVATE qnn_executorch_logging)
154155
target_link_libraries(qnn_logger PRIVATE qnn_implementation ${android_log})
155-
target_link_libraries(qnn_backend PRIVATE qnn_implementation qnn_logger)
156+
target_link_libraries(qnn_backend PRIVATE qnn_implementation qnn_logger qnn_op_package_manager)
156157
target_link_libraries(qnn_custom_protocol PRIVATE qcir_utils)
157158
target_link_libraries(
158159
qnn_device PRIVATE qnn_executorch_logging qnn_implementation qnn_logger
159160
)
160161
target_link_libraries(
161-
qnn_backend_cache PRIVATE qnn_sys_implementation qcir_utils
162+
qnn_backend_cache PRIVATE qnn_sys_implementation qcir_utils qnn_schema
162163
)
163164
target_link_libraries(
164165
qnn_context PRIVATE qnn_implementation qnn_logger qnn_backend qnn_device

backends/qualcomm/builders/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ import torch
176176
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
177177
# op builder will inherit NodeVisitor and have its own implementation
178178
# register_node_visitor for book-keeping the dictionary of target name v.s. callback
179-
from .node_visitor import NodeVisitor, register_node_visitor
179+
from .node_visitor import NodeVisitor
180+
from .node_visitor_manager import register_node_visitor
180181
# the definitions required to build operator in QNN
181182
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
182183
# utility to get parameter value when creating tensor in QNN

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@
6464
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
6565
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
6666
torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
67+
torch.uint32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
6768
float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
69+
int: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
6870
}
6971

7072
PER_CHANNEL_ENCODING = {
@@ -459,51 +461,3 @@ def define_node(
459461
) -> PyQnnWrapper.PyQnnOpWrapper:
460462
"""Convert torch.fx.Node to OpWrapper"""
461463
raise NotImplementedError("NodeVisitor must be extended!")
462-
463-
464-
# This will hold mapping of all node names to the visitor class
465-
_node_visitor_dict = {}
466-
467-
468-
def register_node_visitor(visitor):
469-
"""Register node visitor into _node_visitor_dict"""
470-
assert (
471-
isinstance(visitor, type)
472-
and issubclass(visitor, NodeVisitor)
473-
and hasattr(visitor, "target")
474-
), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
475-
for target in visitor.target:
476-
_node_visitor_dict[target] = visitor
477-
478-
479-
def generate_node_to_external_map(
480-
edge_program: torch.export.ExportedProgram,
481-
) -> Dict[torch.fx.Node, int]:
482-
node_to_external_map = {}
483-
for node in edge_program.graph_module.graph.nodes:
484-
# The order in which we visit the placeholder node is same as the *args
485-
# order for the forward(*args) signature for this gm. Using the order of
486-
# the nodes as external_id to extract the right arg from *args at runtime
487-
if is_graph_input(node, edge_program):
488-
node_to_external_map[node] = len(node_to_external_map)
489-
for node in edge_program.graph_module.graph.nodes:
490-
if is_graph_output(node):
491-
node_to_external_map[node] = len(node_to_external_map)
492-
return node_to_external_map
493-
494-
495-
def get_node_visitors(
496-
edge_program: torch.export.ExportedProgram,
497-
enable_tensor_dump=False,
498-
) -> Dict[str, NodeVisitor]:
499-
"""Create a new class instance at runtime, and put them in a dict"""
500-
node_to_external_map = generate_node_to_external_map(edge_program)
501-
node_visitors = {}
502-
for target, visitor in _node_visitor_dict.items():
503-
assert callable(
504-
visitor
505-
), f"Expeting a callable class, but got {visitor} of type {type(visitor)}"
506-
node_visitors[target] = visitor(
507-
node_to_external_map, edge_program, enable_tensor_dump
508-
)
509-
return node_visitors
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict, List
8+
9+
import torch
10+
from executorch.backends.qualcomm.serialization.qc_schema import (
11+
QnnExecuTorchOpPackageInfo,
12+
)
13+
14+
from .node_visitor import NodeVisitor
15+
from .op_custom_op import CustomOp
16+
from .utils import is_graph_input, is_graph_output
17+
18+
19+
# This will hold mapping of all node names to the visitor class
20+
_node_visitor_dict = {}
21+
22+
23+
def register_node_visitor(visitor):
24+
"""Register node visitor into _node_visitor_dict"""
25+
assert (
26+
isinstance(visitor, type)
27+
and issubclass(visitor, NodeVisitor)
28+
and hasattr(visitor, "target")
29+
), f"Informed NodeVisitor subclass, can't register!, got: {visitor}"
30+
for target in visitor.target:
31+
_node_visitor_dict[target] = visitor
32+
33+
34+
def generate_node_to_external_map(
35+
edge_program: torch.export.ExportedProgram,
36+
) -> Dict[torch.fx.Node, int]:
37+
node_to_external_map = {}
38+
for node in edge_program.graph_module.graph.nodes:
39+
# The order in which we visit the placeholder node is same as the *args
40+
# order for the forward(*args) signature for this gm. Using the order of
41+
# the nodes as external_id to extract the right arg from *args at runtime
42+
if is_graph_input(node, edge_program):
43+
node_to_external_map[node] = len(node_to_external_map)
44+
for node in edge_program.graph_module.graph.nodes:
45+
if is_graph_output(node):
46+
node_to_external_map[node] = len(node_to_external_map)
47+
return node_to_external_map
48+
49+
50+
def get_node_visitors(
51+
edge_program: torch.export.ExportedProgram,
52+
enable_tensor_dump=False,
53+
op_package_infos: List[QnnExecuTorchOpPackageInfo] = None,
54+
) -> Dict[str, NodeVisitor]:
55+
"""Create a new class instance at runtime, and put them in a dict"""
56+
node_to_external_map = generate_node_to_external_map(edge_program)
57+
node_visitors = {}
58+
for target, visitor in _node_visitor_dict.items():
59+
assert callable(
60+
visitor
61+
), f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
62+
node_visitors[target] = visitor(
63+
node_to_external_map, edge_program, enable_tensor_dump
64+
)
65+
if op_package_infos:
66+
custom_ops = []
67+
for op_package_info in op_package_infos:
68+
if op_package_info.custom_op_name not in custom_ops:
69+
custom_op_builder = CustomOp(
70+
op_package_info,
71+
node_to_external_map,
72+
edge_program,
73+
enable_tensor_dump,
74+
)
75+
node_visitors[op_package_info.custom_op_name] = custom_op_builder
76+
custom_ops.append(op_package_info.custom_op_name)
77+
return node_visitors

backends/qualcomm/builders/op_abs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAbs, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_adaptive_avg_pool2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
import torch
1313

14-
from .node_visitor import NodeVisitor, register_node_visitor
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
1516
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
1617

1718

backends/qualcomm/builders/op_add.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAdd, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_amax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import torch
1313
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
1414

15-
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .node_visitor import NodeVisitor
16+
from .node_visitor_manager import register_node_visitor
1617
from .qnn_constants import OpReduceMax, QNN_OP_PACKAGE_NAME_QTI_AISW
1718

1819

backends/qualcomm/builders/op_and.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAnd, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_arange.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314

1415

1516
@register_node_visitor

0 commit comments

Comments
 (0)