
Commit 82684dd

Update on "[ET-VK] Allow specifying multiple storage types/memory layouts for an operator + register group norm operator"

## Changes

* Handle cases where an operator needs to specify a separate storage type / memory layout for each individual output.

## Motivation

Required for the group norm operator.

## Future Work

Currently, the `tag_memory_meta_pass` graph pass assumes that all tensors participating in a computation (aside from weights) have the same storage type and memory layout. As more operators are added, there are more exceptions to this rule. The pass may need an update in the near future so that required storage types and memory layouts can be specified at a more granular level.

Differential Revision: [D77038781](https://our.internmc.facebook.com/intern/diff/D77038781/)

[ghstack-poisoned]
2 parents 7c31608 + 6615378 commit 82684dd
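For context on the "Changes" bullet above: the idea is that an operator can declare a distinct storage type and memory layout per output instead of one setting shared by all of its tensors. A minimal sketch of that idea, using made-up names (`StorageType`, `MemoryLayout`, `group_norm_output_reqs`) rather than the real ET-VK registration API:

```python
# Hypothetical illustration only -- not the actual ET-VK operator registration API.
# It sketches the commit's idea: each output of an operator carries its own
# (storage type, memory layout) requirement instead of one shared setting.
from enum import Enum, auto
from typing import Dict, Tuple


class StorageType(Enum):
    BUFFER = auto()
    TEXTURE_3D = auto()


class MemoryLayout(Enum):
    WIDTH_PACKED = auto()
    CHANNELS_PACKED = auto()


# Per-output requirements for a group-norm-like op returning (out, mean, rstd):
# tuple index -> (storage, layout). Values here are illustrative.
group_norm_output_reqs: Dict[int, Tuple[StorageType, MemoryLayout]] = {
    0: (StorageType.TEXTURE_3D, MemoryLayout.CHANNELS_PACKED),
    1: (StorageType.BUFFER, MemoryLayout.WIDTH_PACKED),
    2: (StorageType.BUFFER, MemoryLayout.WIDTH_PACKED),
}
```

A group-norm-like op is a natural example because its statistics outputs are small and could reasonably prefer a different storage type than the main output.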


42 files changed: +681 / -229 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ It supports a wide range of models including LLMs (Large Language Models), CV (C
 Platform Support:
 - Operating Systems:
   - iOS
-  - Mac
+  - MacOS (ARM64)
   - Android
   - Linux
   - Microcontrollers

backends/arm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if(NOT EXECUTORCH_ROOT)
   set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
 endif()
 
+add_compile_options("-Wall" "-Werror")
+
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 
 set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,6 +8,7 @@
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .annotate_stack import AnnotateStack
 from .annotate_unbind import AnnotateUnbind
+from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -35,7 +36,6 @@
 from .remove_0d_tensor import Remove0DTensor
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
-from .replace_index_put_input import ReplaceIndexPutInput
 from .replace_inf_values import ReplaceInfValues
 from .tag_quant_io import TagQuantIO
 
@@ -45,6 +45,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -72,7 +73,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 ]
backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import operator
+from collections import Counter
+from typing import List
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+class ConvertBmmToMatmul(ExportPass):
+    """
+    Replace bmm with matmul, because bmm is equal to matmul in QNN.
+    Handle missing quantization tag for bmm op.
+    """
+
+    view_copy = exir_ops.edge.aten.view_copy.default
+    expand_copy = exir_ops.edge.aten.expand_copy.default
+    clone = exir_ops.edge.aten.clone.default
+    bmm = exir_ops.edge.aten.bmm.default
+    matmul = exir_ops.edge.aten.matmul.default
+    patterns = [
+        {expand_copy: 2, view_copy: 3, bmm: 1},
+        {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
+        {bmm: 1},
+    ]
+
+    def __init__(self):
+        super(ConvertBmmToMatmul, self).__init__()
+
+    def _get_ordered_inputs(
+        self, inputs: List[torch.fx.Node], output: torch.fx.Node
+    ) -> List[torch.fx.Node]:
+        bmm_inputs = []
+        for arg in output.args:
+            while arg not in inputs:
+                arg = arg.args[0]
+            bmm_inputs.append(arg)
+        return bmm_inputs
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        partitions = get_source_partitions(
+            graph,
+            [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
+        )
+        for _, src_partitions in partitions.items():
+            for src_partition in src_partitions:
+                op_cnt = Counter([n.target for n in src_partition.nodes])
+                if op_cnt not in self.patterns:
+                    raise AssertionError(
+                        "Found a new pattern that needs to be converted to a matmul op"
+                    )
+
+                inputs = src_partition.input_nodes
+                bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]
+                output = src_partition.output_nodes[0]
+                # the order of src_partition.inputs is not guaranteed
+                lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
+                with graph_module.graph.inserting_before(output):
+                    # replace bmm with matmul, because bmm is equal to matmul in QNN
+                    matmul_node = graph.create_node(
+                        "call_function", self.matmul, (lhs, rhs)
+                    )
+                    matmul_node.meta = output.meta
+                    for user in output.users.copy():
+                        user.replace_input_with(output, matmul_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
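The rewrite above relies on `torch.bmm` and `torch.matmul` computing the same batched product for 3-D inputs, which is why the two are interchangeable for QNN. A quick self-contained check of that equivalence (not part of the commit):

```python
# Sanity check: for two 3-D tensors, torch.matmul performs the same batched
# multiplication as torch.bmm, so rewriting bmm -> matmul preserves semantics.
import torch

a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
assert torch.allclose(torch.bmm(a, b), torch.matmul(a, b))
```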

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 8 additions & 2 deletions
@@ -9,7 +9,10 @@
 
 from executorch.backends.qualcomm.builders.node_visitor import q_ops
 
-from executorch.backends.qualcomm.builders.utils import is_parameter
+from executorch.backends.qualcomm.builders.utils import (
+    is_mutable_buffer_input,
+    is_parameter,
+)
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_ENCODING,
     QCOM_QUANT_ATTRS,
@@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if (
                 n.op == "placeholder"
                 and n.meta.get(QCOM_QUANT_ATTRS)
-                and not is_parameter(n, self.edge_program)
+                and (
+                    not is_parameter(n, self.edge_program)
+                    or is_mutable_buffer_input(n, self.edge_program)
+                )
             ):
                 self._insert_quant_node(
                     graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 10 additions & 3 deletions
@@ -13,6 +13,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -40,7 +41,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 )
@@ -80,6 +80,7 @@ def get_capture_program_passes():
         (AnnotateQuantAttrs, True),
         (AnnotateStack, True),
         (AnnotateUnbind, True),
+        (ConvertBmmToMatmul, False),
         (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),
@@ -92,7 +93,6 @@ def get_capture_program_passes():
         (RecomposeRmsNorm, False),
         (Remove0DTensor, True),
         (RemoveRedundancy, True),
-        (ReplaceIndexPutInput, True),
         (TagQuantIO, False),
     ]
 
@@ -224,4 +224,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(LayoutTransform(exported_program, insert_permute=True))
         self.add_pass(FuseConsecutiveCast())
         self.add_pass(FuseConsecutiveTranspose())
-        return self._transform(exported_program.graph_module)
+        self._transform(exported_program.graph_module)
+        # Update inputs_to_buffers and buffers_to_mutate in the graph signature for mutable buffers.
+        # Since Q/DQ nodes are inserted at the graph I/O, output node names would otherwise fail to map to their buffers.
+        exported_program._graph_signature = _get_updated_graph_signature(
+            exported_program.graph_signature,
+            exported_program.graph_module,
+        )
+        return exported_program.graph_module

backends/qualcomm/_passes/replace_index_put_input.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

backends/qualcomm/_passes/utils.py

Lines changed: 4 additions & 3 deletions
@@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program():
         AnnotateQuantAttrs,
         AnnotateStack,
         AnnotateUnbind,
+        ConvertBmmToMatmul,
         ConvertConv1dToConv2d,
         DecomposeAny,
         DecomposeColIm,
@@ -76,18 +77,19 @@ def get_passes_dependency_for_capture_program():
         RecomposePixelUnshuffle,
         RecomposeRmsNorm,
         RemoveRedundancy,
-        ReplaceIndexPutInput,
         TagQuantIO,
     )
 
     return {
         AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
         AnnotateQuantAttrs: [
+            ConvertBmmToMatmul,
             RecomposePixelUnshuffle,
             RemoveRedundancy,
         ],
         AnnotateStack: [RemoveRedundancy],
         AnnotateUnbind: [RemoveRedundancy],
+        ConvertBmmToMatmul: [RecomposePixelUnshuffle],
         DecomposeAny: [RemoveRedundancy],
         DecomposeColIm: [FoldQDQ],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
@@ -103,8 +105,7 @@ def get_passes_dependency_for_capture_program():
         ],
         RecomposePixelUnshuffle: [RemoveRedundancy],
         RecomposeRmsNorm: [RemoveRedundancy],
-        ReplaceIndexPutInput: [LayoutTransform],
-        TagQuantIO: [ReplaceIndexPutInput],
+        TagQuantIO: [LayoutTransform],
     }
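The dependency table above only records ordering constraints between passes. A sketch (not from this commit) of how such a table can be resolved into an execution order with a depth-first topological sort, assuming each key lists the passes that must run before it; the real scheduling lives in `QnnPassManager`:

```python
# Toy topological sort over a pass-dependency table. Assumption: deps[p] lists
# the passes that must run before p; the two entries below are taken from the
# table above purely as an example.
from typing import Dict, List


def order_passes(deps: Dict[str, List[str]]) -> List[str]:
    visited, order = set(), []

    def visit(p: str) -> None:
        if p in visited:
            return
        visited.add(p)
        for prerequisite in deps.get(p, []):
            visit(prerequisite)  # schedule prerequisites first
        order.append(p)

    for p in deps:
        visit(p)
    return order


print(order_passes({
    "AnnotateQuantAttrs": ["ConvertBmmToMatmul"],
    "ConvertBmmToMatmul": ["RecomposePixelUnshuffle"],
}))
# ['RecomposePixelUnshuffle', 'ConvertBmmToMatmul', 'AnnotateQuantAttrs']
```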

backends/qualcomm/builders/node_visitor.py

Lines changed: 33 additions & 11 deletions
@@ -41,6 +41,8 @@
     get_parameter,
     is_graph_input,
     is_graph_output,
+    is_mutable_buffer_input,
+    is_mutable_buffer_output,
     is_parameter,
 )
 
@@ -307,7 +309,9 @@ def get_tensor_type(
         node: torch.fx.Node,
         tensor_type: PyQnnWrapper.Qnn_TensorType_t,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        is_input = is_graph_input(node, self.edge_program)
+        is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
+            node, self.edge_program
+        )
         is_output = is_graph_output(node)
         # handle logic for input/output tensors
         if is_input or is_output:
@@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims):
 
         return dynamic_dims if any(dynamic_dims) else [], nominal_dims
 
+    def get_tensor_name(
+        self,
+        node: torch.fx.Node,
+        wrapper_idx: int = 0,
+    ):
+        tensor_name = f"{node.name}_{wrapper_idx}"
+        # The `input_{id}` prefix is used for sorting at runtime. Due to multiple passes in qnn_preprocess,
+        # the input order between QNN and the original graph's forward function may differ.
+        # The `mutbuf_{id}` tag is used to map the I/O of a mutable buffer at runtime.
+        # The `output_` prefix marks the graph's output at runtime to avoid confusion with per_tensor_dump.
+        if is_mutable_buffer_input(node, self.edge_program):
+            fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.values()
+            ).index(fqn)
+            tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
+        elif is_graph_input(node, self.edge_program):
+            tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
+        elif is_mutable_buffer_output(node, self.edge_program):
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.keys()
+            ).index(node.name)
+            tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
+        elif is_graph_output(node):
+            tensor_name = f"output_{tensor_name}"
+        return tensor_name
+
     def define_custom_tensor_wrapper(
         self,
         node_name: str,
@@ -413,16 +444,7 @@ def define_tensor(
         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
             return cached
 
-        tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
-        if is_graph_input(tensor_source_node, self.edge_program):
-            tensor_name = (
-                "input_"
-                + str(self.external_ids[tensor_source_node])
-                + "_"
-                + tensor_name
-            )
-        if is_graph_output(tensor_source_node):
-            tensor_name = "output_" + tensor_name
+        tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
         dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
         dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)
         tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
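To make the naming prefixes described in `get_tensor_name` concrete, here is a small self-contained sketch (not part of the commit). The helper and the example names are illustrative; the real code derives the ids and buffer positions from `external_ids` and the graph signature as shown above:

```python
# Illustrative only: mirrors the prefix scheme from get_tensor_name without
# touching an actual exported program.
def make_tensor_name(base: str, external_id=None, mutbuf_pos=None, is_output=False) -> str:
    if is_output and mutbuf_pos is not None:
        return f"output_mutbuf_{mutbuf_pos}_{base}"                # mutated-buffer output
    if is_output:
        return f"output_{base}"                                    # ordinary graph output
    if mutbuf_pos is not None:
        return f"input_{external_id}_mutbuf_{mutbuf_pos}_{base}"   # mutable-buffer input
    if external_id is not None:
        return f"input_{external_id}_{base}"                       # ordinary graph input
    return base                                                    # weights / intermediates


print(make_tensor_name("kv_cache_0", external_id=3, mutbuf_pos=0))  # input_3_mutbuf_0_kv_cache_0
print(make_tensor_name("add_0", is_output=True))                     # output_add_0
```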

backends/qualcomm/builders/node_visitor_manager.py

Lines changed: 4 additions & 2 deletions
@@ -13,7 +13,7 @@
 
 from .node_visitor import NodeVisitor
 from .op_custom_op import CustomOp
-from .utils import is_graph_input, is_graph_output
+from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input
 
 
 # This will hold mapping of all node names to the visitor class
@@ -39,7 +39,9 @@ def generate_node_to_external_map(
         # The order in which we visit the placeholder nodes is the same as the *args
         # order of the forward(*args) signature for this gm. Use the order of
         # the nodes as external_id to extract the right arg from *args at runtime.
-        if is_graph_input(node, edge_program):
+        if is_graph_input(node, edge_program) or is_mutable_buffer_input(
+            node, edge_program
+        ):
             node_to_external_map[node] = len(node_to_external_map)
     for node in edge_program.graph_module.graph.nodes:
         if is_graph_output(node):
