
Commit e922f14

Qualcomm AI Engine Direct - Delegate mutable buffer and fix the mutable buffer issue
Summary:
- Add a parameter to support mutable buffer delegation in the QNN backend
- Set the same memory address for the I/O of a mutable buffer at runtime
- Avoid annotating the input node, since mutable buffers are folded during the convert_pt2e process
- Deprecate use_legacy_export in the ExecuTorch llama export
1 parent 44d2643 commit e922f14
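As a usage sketch (not part of this diff): with these changes, delegating a model whose state lives in a mutable buffer only requires exporting as usual and leaving the new flag at its default. Everything below except QnnPartitioner and skip_mutable_buffer — the toy module, the chipset choice, and the spec helpers — is illustrative and may differ across ExecuTorch versions.

import torch
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)
from executorch.exir import to_edge_transform_and_lower

class KVCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(1, 8))  # becomes a buffer in the exported program

    def forward(self, x):
        self.cache.add_(x)  # in-place mutation -> buffers_to_mutate entry
        return self.cache.clone()

ep = torch.export.export(KVCache(), (torch.ones(1, 8),))
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # illustrative target
    backend_options=generate_htp_compiler_spec(use_fp16=True),
)
# skip_mutable_buffer defaults to False, so the cache is delegated to QNN
partitioner = QnnPartitioner(compiler_specs, skip_mutable_buffer=False)
edge = to_edge_transform_and_lower(ep, partitioner=[partitioner])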

File tree

22 files changed: +357 −177 lines

backends/qualcomm/_passes/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -35,7 +35,6 @@
 from .remove_0d_tensor import Remove0DTensor
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
-from .replace_index_put_input import ReplaceIndexPutInput
 from .replace_inf_values import ReplaceInfValues
 from .tag_quant_io import TagQuantIO

@@ -72,7 +71,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 ]

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 8 additions & 2 deletions
@@ -9,7 +9,10 @@

 from executorch.backends.qualcomm.builders.node_visitor import q_ops

-from executorch.backends.qualcomm.builders.utils import is_parameter
+from executorch.backends.qualcomm.builders.utils import (
+    is_mutable_buffer_input,
+    is_parameter,
+)
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_ENCODING,
     QCOM_QUANT_ATTRS,

@@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         if (
             n.op == "placeholder"
             and n.meta.get(QCOM_QUANT_ATTRS)
-            and not is_parameter(n, self.edge_program)
+            and (
+                not is_parameter(n, self.edge_program)
+                or is_mutable_buffer_input(n, self.edge_program)
+            )
         ):
             self._insert_quant_node(
                 graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 8 additions & 3 deletions
@@ -40,7 +40,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 )

@@ -92,7 +91,6 @@ def get_capture_program_passes():
         (RecomposeRmsNorm, False),
         (Remove0DTensor, True),
         (RemoveRedundancy, True),
-        (ReplaceIndexPutInput, True),
         (TagQuantIO, False),
     ]

@@ -224,4 +222,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(LayoutTransform(exported_program, insert_permute=True))
         self.add_pass(FuseConsecutiveCast())
         self.add_pass(FuseConsecutiveTranspose())
-        return self._transform(exported_program.graph_module)
+        self._transform(exported_program.graph_module)
+        # Update inputs_to_buffers and buffers_to_mutate in the graph signature for mutable buffers.
+        # Since Q/DQ nodes are inserted at the I/O, output node names would otherwise fail to map back to their buffers.
+        exported_program._graph_signature = _get_updated_graph_signature(
+            exported_program.graph_signature,
+            exported_program.graph_module,
+        )
+        return exported_program.graph_module
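For context on why the signature refresh matters: buffers_to_mutate keys on the name of the output node that produces the mutated value, so once Q/DQ nodes are appended at the I/O, the recorded names go stale. A minimal sketch of the mapping being inspected (toy module; the printed node names depend on the torch version):

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x):
        self.count.add_(x)  # mutation recorded in the graph signature
        return self.count.clone()

ep = torch.export.export(Counter(), (torch.ones(1),))
# Maps output-node name -> buffer FQN, e.g. {'add': 'count'}; inserting a
# Q/DQ node after `add` changes the graph output, so the signature must be
# re-derived from the transformed graph_module.
print(ep.graph_signature.buffers_to_mutate)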

backends/qualcomm/_passes/replace_index_put_input.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

backends/qualcomm/_passes/utils.py

Lines changed: 1 addition & 3 deletions
@@ -76,7 +76,6 @@ def get_passes_dependency_for_capture_program():
        RecomposePixelUnshuffle,
        RecomposeRmsNorm,
        RemoveRedundancy,
-       ReplaceIndexPutInput,
        TagQuantIO,
    )

@@ -103,8 +102,7 @@ def get_passes_dependency_for_capture_program():
        ],
        RecomposePixelUnshuffle: [RemoveRedundancy],
        RecomposeRmsNorm: [RemoveRedundancy],
-       ReplaceIndexPutInput: [LayoutTransform],
-       TagQuantIO: [ReplaceIndexPutInput],
+       TagQuantIO: [LayoutTransform],
    }

backends/qualcomm/builders/node_visitor.py

Lines changed: 33 additions & 11 deletions
@@ -41,6 +41,8 @@
     get_parameter,
     is_graph_input,
     is_graph_output,
+    is_mutable_buffer_input,
+    is_mutable_buffer_output,
     is_parameter,
 )

@@ -307,7 +309,9 @@ def get_tensor_type(
         node: torch.fx.Node,
         tensor_type: PyQnnWrapper.Qnn_TensorType_t,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        is_input = is_graph_input(node, self.edge_program)
+        is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
+            node, self.edge_program
+        )
         is_output = is_graph_output(node)
         # handle logic for input/output tensors
         if is_input or is_output:

@@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims):

         return dynamic_dims if any(dynamic_dims) else [], nominal_dims

+    def get_tensor_name(
+        self,
+        node: torch.fx.Node,
+        wrapper_idx: int = 0,
+    ):
+        tensor_name = f"{node.name}_{wrapper_idx}"
+        # The `input_{id}` prefix is used for sorting at runtime; due to multiple passes in
+        # qnn_preprocess, the input order between QNN and the original graph's forward function may differ.
+        # The `mutbuf_{id}` prefix is used to map the I/O of a mutable buffer at runtime.
+        # The `output_` prefix identifies a graph output at runtime, to avoid confusion with per_tensor_dump.
+        if is_mutable_buffer_input(node, self.edge_program):
+            fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.values()
+            ).index(fqn)
+            tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
+        elif is_graph_input(node, self.edge_program):
+            tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
+        elif is_mutable_buffer_output(node, self.edge_program):
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.keys()
+            ).index(node.name)
+            tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
+        elif is_graph_output(node):
+            tensor_name = f"output_{tensor_name}"
+        return tensor_name
+
     def define_custom_tensor_wrapper(
         self,
         node_name: str,

@@ -413,16 +444,7 @@ def define_tensor(
         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
             return cached

-        tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
-        if is_graph_input(tensor_source_node, self.edge_program):
-            tensor_name = (
-                "input_"
-                + str(self.external_ids[tensor_source_node])
-                + "_"
-                + tensor_name
-            )
-        if is_graph_output(tensor_source_node):
-            tensor_name = "output_" + tensor_name
+        tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
         dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
         dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)
         tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
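To make the naming scheme concrete, here is a hedged illustration (the node names are invented) of what get_tensor_name produces and how the shared mutbuf index lets the runtime pair a mutable buffer's input with its output tensor:

import re

# Hypothetical names for a graph with one user input and one mutable buffer:
#   input_0_x_0                 ordinary graph input, external id 0
#   input_1_mutbuf_0_b_cache_0  mutable-buffer input at mutbuf position 0
#   output_mutbuf_0_aten_add_0  mutated value written back to that buffer
#   output_aten_clone_0         ordinary graph output

def mutbuf_position(tensor_name: str):
    """Extract the mutbuf index the runtime uses to pair buffer I/O."""
    m = re.search(r"mutbuf_(\d+)", tensor_name)
    return int(m.group(1)) if m else None

# The input and output of the same buffer share position 0, so the runtime
# can bind both sides to the same memory address.
assert mutbuf_position("input_1_mutbuf_0_b_cache_0") == 0
assert mutbuf_position("output_mutbuf_0_aten_add_0") == 0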

backends/qualcomm/builders/node_visitor_manager.py

Lines changed: 4 additions & 2 deletions
@@ -13,7 +13,7 @@

 from .node_visitor import NodeVisitor
 from .op_custom_op import CustomOp
-from .utils import is_graph_input, is_graph_output
+from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input


 # This will hold mapping of all node names to the visitor class

@@ -39,7 +39,9 @@ def generate_node_to_external_map(
         # The order in which we visit the placeholder nodes is the same as the *args
         # order for the forward(*args) signature for this gm. Use the order of
         # the nodes as the external_id to extract the right arg from *args at runtime.
-        if is_graph_input(node, edge_program):
+        if is_graph_input(node, edge_program) or is_mutable_buffer_input(
+            node, edge_program
+        ):
             node_to_external_map[node] = len(node_to_external_map)
     for node in edge_program.graph_module.graph.nodes:
         if is_graph_output(node):

backends/qualcomm/builders/op_index_put.py

Lines changed: 6 additions & 1 deletion
@@ -1,9 +1,10 @@
 from typing import Dict

 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import torch

+from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW

@@ -22,6 +23,10 @@ def define_node(
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
         input_node = self.get_node(node.args[0])
+        # Because args[0] of the index_put op is not annotated, fill in its quant_attrs from this node.
+        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
+            quant_attrs = quant_attrs.copy()
+            input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
         input_tensor = self.get_tensor(input_node, node)
         input_tensor_wrapper = self.define_tensor(
             input_node,

backends/qualcomm/builders/utils.py

Lines changed: 34 additions & 1 deletion
@@ -75,6 +75,20 @@ def is_graph_input(
     return tensor.op == "placeholder" and not is_parameter(tensor, edge_program)


+def is_mutable_buffer_input(
+    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
+) -> bool:
+    """
+    Check if the given tensor is a mutable buffer input
+    Args:
+        tensor: EdgeIR Tensor that is being checked for mutable buffer input
+    """
+    if tensor.op == "placeholder" and is_buffer(edge_program, tensor):
+        fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target]
+        # the buffer is mutable if it is also recorded in buffers_to_mutate
+        return fqn in edge_program.graph_signature.buffers_to_mutate.values()
+
+
 def is_graph_output(node: torch.fx.Node) -> bool:
     """
     Check if the given tensor is used as a graph output

@@ -83,14 +97,33 @@ def is_graph_output(node: torch.fx.Node) -> bool:
     Args:
         tensor: EdgeIR Tensor that is being checked for graph output
     """
     for user in node.users.keys():
-        # getitem node is skiped, check the op_skip_ops.py
+        # getitem node is skipped, check the op_skip_ops.py
         if user.op == "output" or (
             user.target.__name__ == "getitem" and is_graph_output(user)
         ):
             return True
     return False


+def is_mutable_buffer_output(
+    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
+) -> bool:
+    """
+    Check if the given tensor is a mutable buffer output
+    Args:
+        tensor: EdgeIR Tensor that is being checked for mutable buffer output
+    """
+    return (
+        any(
+            user.op == "output"
+            or user.target.__name__ == "getitem"
+            and is_graph_output(user)
+            for user in tensor.users.keys()
+        )
+        and tensor.name in edge_program.graph_signature.buffers_to_mutate.keys()
+    )
+
+
 def is_constant(
     tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
 ) -> bool:
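A quick sanity check of what these helpers key on, assuming a toy module with one mutated and one read-only buffer (names are illustrative): only the mutated buffer's FQN shows up in buffers_to_mutate, so is_mutable_buffer_input is True for its placeholder and False for the read-only one.

import torch
from executorch.backends.qualcomm.builders.utils import is_mutable_buffer_input

class TwoBuffers(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(2))  # mutated below -> mutable
        self.register_buffer("table", torch.ones(2))   # read-only

    def forward(self, x):
        self.state.mul_(x)
        return self.state + self.table

ep = torch.export.export(TwoBuffers(), (torch.ones(2),))
for n in ep.graph_module.graph.nodes:
    if n.op == "placeholder":
        # expected: True for the `state` placeholder, False for `table` and x
        print(n.name, is_mutable_buffer_input(n, ep))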

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 16 additions & 1 deletion
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import logging
 from collections import defaultdict
 from typing import Any, Callable, Dict, List, Optional, Tuple

@@ -29,7 +30,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import OperatorSupportBase

@@ -42,6 +43,9 @@
 )
 from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table

+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+

 class QnnOperatorSupport(OperatorSupportBase):
     def __init__(

@@ -124,6 +128,7 @@ def __init__(
         compiler_specs: List[CompileSpec],
         skip_node_id_set: set = None,
         skip_node_op_set: set = None,
+        skip_mutable_buffer: bool = False,
     ):
         self.compiler_specs_snapshot = copy.deepcopy(compiler_specs)

@@ -133,6 +138,7 @@ def __init__(
         self.partition_tags: Dict[str, DelegationSpec] = {}
         self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set
         self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set
+        self.skip_mutable_buffer = skip_mutable_buffer

     def generate_partitions(
         self, edge_program: torch.export.ExportedProgram

@@ -178,6 +184,15 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu
         if len(partitions) != 0:
             self.tag_nodes(partitions, edge_program)
             tag_constant_data(edge_program)
+            if not self.skip_mutable_buffer:
+                logger.info(
+                    "The QNN partitioner delegates torch mutable buffers and gives them the same I/O address at runtime, "
+                    "so if your model contains mutable buffers, "
+                    "you can get better performance with skip_mutable_buffer=False. "
+                    "If you encounter accuracy issues at runtime, "
+                    "please set `skip_mutable_buffer=True` and try again."
+                )
+                tag_mutated_buffer(edge_program)
         for node in edge_program.graph_module.graph.nodes:
             if hasattr(node, "meta"):
                 # pop certain keys in meta so they do not affect the passes during compilation
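Usage note, hedged: the flag is a performance/accuracy trade-off. Delegating the buffer keeps reads and writes inside the QNN graph with shared I/O addresses (no copy back through the ExecuTorch runtime); skipping delegation is the escape hatch if the quantized round-trip of the buffer hurts accuracy. Assuming compiler_specs built as in the sketch near the top of this page:

# Default: mutable buffers are tagged via tag_mutated_buffer() and delegated.
partitioner = QnnPartitioner(compiler_specs)

# Fallback if delegated mutable buffers introduce accuracy issues:
partitioner = QnnPartitioner(compiler_specs, skip_mutable_buffer=True)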
