4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/__init__.py
@@ -8,6 +8,7 @@
from .annotate_quant_attrs import AnnotateQuantAttrs
from .annotate_stack import AnnotateStack
from .annotate_unbind import AnnotateUnbind
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_square_to_pow import ConvertSquareToPow
from .decompose_any import DecomposeAny
@@ -35,7 +36,6 @@
from .remove_0d_tensor import Remove0DTensor
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_index_put_input import ReplaceIndexPutInput
from .replace_inf_values import ReplaceInfValues
from .tag_quant_io import TagQuantIO

@@ -45,6 +45,7 @@
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertSquareToPow,
DecomposeAny,
@@ -72,7 +73,6 @@
Remove0DTensor,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
ReplaceInfValues,
TagQuantIO,
]
76 changes: 76 additions & 0 deletions backends/qualcomm/_passes/convert_bmm_to_matmul.py
@@ -0,0 +1,76 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from collections import Counter
from typing import List

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class ConvertBmmToMatmul(ExportPass):
"""
Replace bmm with matmul, since bmm is equivalent to matmul in QNN.
Handle the missing quantization tag for the bmm op.
"""

view_copy = exir_ops.edge.aten.view_copy.default
expand_copy = exir_ops.edge.aten.expand_copy.default
clone = exir_ops.edge.aten.clone.default
bmm = exir_ops.edge.aten.bmm.default
matmul = exir_ops.edge.aten.matmul.default
patterns = [
{expand_copy: 2, view_copy: 3, bmm: 1},
{expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
{bmm: 1},
]

def __init__(self):
super(ConvertBmmToMatmul, self).__init__()

def _get_ordered_inputs(
self, inputs: List[torch.fx.Node], output: torch.fx.Node
) -> List[torch.fx.Node]:
bmm_inputs = []
for arg in output.args:
while arg not in inputs:
arg = arg.args[0]
bmm_inputs.append(arg)
return bmm_inputs

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
partitions = get_source_partitions(
graph,
[operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
)
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
op_cnt = Counter([n.target for n in src_partition.nodes])
if op_cnt not in self.patterns:
raise AssertionError(
"Found a new pattern needed be converted to linear op"
)

inputs = src_partition.input_nodes
bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]
output = src_partition.output_nodes[0]
# The order of src_partition.input_nodes is not guaranteed.
lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
with graph_module.graph.inserting_before(output):
# Replace bmm with matmul, since bmm is equivalent to matmul in QNN.
matmul_node = graph.create_node(
"call_function", self.matmul, (lhs, rhs)
)
matmul_node.meta = output.meta
for user in output.users.copy():
user.replace_input_with(output, matmul_node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
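A quick illustrative sketch (not part of this diff) of the premise behind ConvertBmmToMatmul: for 3-D inputs, torch.bmm and torch.matmul compute the same batched product, so replacing the edge-dialect bmm with matmul is numerically a no-op.

# Illustrative only: bmm and matmul agree on 3-D (batched) inputs.
import torch

a = torch.randn(4, 8, 16)
b = torch.randn(4, 16, 32)
assert torch.allclose(torch.bmm(a, b), torch.matmul(a, b))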
10 changes: 8 additions & 2 deletions backends/qualcomm/_passes/insert_io_qdq.py
@@ -9,7 +9,10 @@

from executorch.backends.qualcomm.builders.node_visitor import q_ops

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.builders.utils import (
is_mutable_buffer_input,
is_parameter,
)
from executorch.backends.qualcomm.utils.constants import (
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
@@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
if (
n.op == "placeholder"
and n.meta.get(QCOM_QUANT_ATTRS)
and not is_parameter(n, self.edge_program)
and (
not is_parameter(n, self.edge_program)
or is_mutable_buffer_input(n, self.edge_program)
)
):
self._insert_quant_node(
graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
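For context, a rough sketch (module name and shapes are made up) of the kind of placeholder this new condition covers: a module that mutates a registered buffer in forward(). After torch.export, the buffer appears both as a placeholder input (inputs_to_buffers) and as a mutated output (buffers_to_mutate), so is_parameter alone would skip it even though it still needs an I/O quant node.

# Illustrative only: a mutable buffer recorded in the exported graph signature.
import torch

class ToyCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(4, 8))

    def forward(self, x):
        # The in-place update marks `cache` as a mutable buffer in the graph signature.
        self.cache.add_(x)
        return self.cache * 2

ep = torch.export.export(ToyCache(), (torch.randn(4, 8),))
print(ep.graph_signature.inputs_to_buffers)   # placeholder name -> 'cache'
print(ep.graph_signature.buffers_to_mutate)   # mutated output node -> 'cache'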
13 changes: 10 additions & 3 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -13,6 +13,7 @@
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertSquareToPow,
DecomposeAny,
@@ -40,7 +41,6 @@
Remove0DTensor,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
ReplaceInfValues,
TagQuantIO,
)
@@ -80,6 +80,7 @@ def get_capture_program_passes():
(AnnotateQuantAttrs, True),
(AnnotateStack, True),
(AnnotateUnbind, True),
(ConvertBmmToMatmul, False),
(ConvertConv1dToConv2d, True),
(DecomposeAny, True),
(DecomposeColIm, True),
@@ -92,7 +93,6 @@
(RecomposeRmsNorm, False),
(Remove0DTensor, True),
(RemoveRedundancy, True),
(ReplaceIndexPutInput, True),
(TagQuantIO, False),
]

@@ -224,4 +224,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
self.add_pass(FuseConsecutiveCast())
self.add_pass(FuseConsecutiveTranspose())
return self._transform(exported_program.graph_module)
self._transform(exported_program.graph_module)
# Update inputs_to_buffers and buffers_to_mutate in the graph signature for mutable buffers.
# Since Q/DQ nodes are inserted at the I/O, the output node names would otherwise fail to map to their buffers.
exported_program._graph_signature = _get_updated_graph_signature(
exported_program.graph_signature,
exported_program.graph_module,
)
return exported_program.graph_module
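Since ConvertBmmToMatmul is registered above with a default of False, callers have to opt in. A rough sketch of what that toggle might look like, assuming the backend's passes_job dict and QCOM_PASS_ACTIVATE_KEY convention (taken from the surrounding codebase, not from this diff):

# Sketch under assumptions: enable the optional ConvertBmmToMatmul pass.
from executorch.backends.qualcomm._passes import ConvertBmmToMatmul
from executorch.backends.qualcomm._passes.qnn_pass_manager import get_capture_program_passes
from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY  # assumed constant name

passes_job = get_capture_program_passes()
passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True
# passes_job would then be handed to the capture/lowering entry point.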
54 changes: 0 additions & 54 deletions backends/qualcomm/_passes/replace_index_put_input.py

This file was deleted.

7 changes: 4 additions & 3 deletions backends/qualcomm/_passes/utils.py
@@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program():
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
DecomposeAny,
DecomposeColIm,
@@ -76,18 +77,19 @@
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
ReplaceIndexPutInput,
TagQuantIO,
)

return {
AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
AnnotateQuantAttrs: [
ConvertBmmToMatmul,
RecomposePixelUnshuffle,
RemoveRedundancy,
],
AnnotateStack: [RemoveRedundancy],
AnnotateUnbind: [RemoveRedundancy],
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
DecomposeAny: [RemoveRedundancy],
DecomposeColIm: [FoldQDQ],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
@@ -103,8 +105,7 @@
],
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
ReplaceIndexPutInput: [LayoutTransform],
TagQuantIO: [ReplaceIndexPutInput],
TagQuantIO: [LayoutTransform],
}


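The dependency table above maps each pass to the passes that must run before it. As a side illustration (this is not the backend's own scheduler), a valid execution order can be derived from such a table with a plain topological sort:

# Illustrative only: derive a run order from a trimmed, hypothetical dependency table.
from graphlib import TopologicalSorter

deps = {
    "AnnotateQuantAttrs": ["ConvertBmmToMatmul", "RecomposePixelUnshuffle", "RemoveRedundancy"],
    "ConvertBmmToMatmul": ["RecomposePixelUnshuffle"],
    "RecomposePixelUnshuffle": ["RemoveRedundancy"],
    "TagQuantIO": ["LayoutTransform"],
}

# Prerequisites are emitted before the passes that depend on them.
print(list(TopologicalSorter(deps).static_order()))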
44 changes: 33 additions & 11 deletions backends/qualcomm/builders/node_visitor.py
@@ -41,6 +41,8 @@
get_parameter,
is_graph_input,
is_graph_output,
is_mutable_buffer_input,
is_mutable_buffer_output,
is_parameter,
)

@@ -307,7 +309,9 @@ def get_tensor_type(
node: torch.fx.Node,
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
) -> PyQnnWrapper.Qnn_TensorType_t:
is_input = is_graph_input(node, self.edge_program)
is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
node, self.edge_program
)
is_output = is_graph_output(node)
# handle logic for input/output tensors
if is_input or is_output:
@@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims):

return dynamic_dims if any(dynamic_dims) else [], nominal_dims

def get_tensor_name(
self,
node: torch.fx.Node,
wrapper_idx: int = 0,
):
tensor_name = f"{node.name}_{wrapper_idx}"
# The `input_{id}` prefix is used for sorting at runtime; after the multiple passes in qnn_preprocess,
# the input order seen by QNN may differ from the original graph's forward() signature.
# The `mutbuf_{id}` prefix is used at runtime to map the I/O of a mutable buffer.
# The `output_` prefix marks a tensor as a graph output at runtime, to avoid confusion with per_tensor_dump.
if is_mutable_buffer_input(node, self.edge_program):
fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
position_index = list(
self.edge_program.graph_signature.buffers_to_mutate.values()
).index(fqn)
tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
elif is_graph_input(node, self.edge_program):
tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
elif is_mutable_buffer_output(node, self.edge_program):
position_index = list(
self.edge_program.graph_signature.buffers_to_mutate.keys()
).index(node.name)
tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
elif is_graph_output(node):
tensor_name = f"output_{tensor_name}"
return tensor_name

def define_custom_tensor_wrapper(
self,
node_name: str,
@@ -413,16 +444,7 @@ def define_tensor(
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached

tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
if is_graph_input(tensor_source_node, self.edge_program):
tensor_name = (
"input_"
+ str(self.external_ids[tensor_source_node])
+ "_"
+ tensor_name
)
if is_graph_output(tensor_source_node):
tensor_name = "output_" + tensor_name
tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)
tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
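A condensed sketch of the naming scheme get_tensor_name encodes, using made-up ids (the real method derives them from external_ids and the graph signature):

# Illustrative only: the prefixes the runtime parses from tensor names.
def example_tensor_name(base, wrapper_idx=0, input_id=None, mutbuf_id=None, is_output=False):
    name = f"{base}_{wrapper_idx}"
    if input_id is not None and mutbuf_id is not None:
        return f"input_{input_id}_mutbuf_{mutbuf_id}_{name}"  # mutable-buffer input
    if input_id is not None:
        return f"input_{input_id}_{name}"                     # ordinary graph input
    if is_output and mutbuf_id is not None:
        return f"output_mutbuf_{mutbuf_id}_{name}"            # mutable-buffer output
    if is_output:
        return f"output_{name}"                               # ordinary graph output
    return name                                               # intermediate tensor

print(example_tensor_name("k_cache", input_id=2, mutbuf_id=0))  # input_2_mutbuf_0_k_cache_0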
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/node_visitor_manager.py
@@ -13,7 +13,7 @@

from .node_visitor import NodeVisitor
from .op_custom_op import CustomOp
from .utils import is_graph_input, is_graph_output
from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input


# This will hold mapping of all node names to the visitor class
@@ -39,7 +39,9 @@ def generate_node_to_external_map(
# The order in which we visit the placeholder node is same as the *args
# order for the forward(*args) signature for this gm. Using the order of
# the nodes as external_id to extract the right arg from *args at runtime
if is_graph_input(node, edge_program):
if is_graph_input(node, edge_program) or is_mutable_buffer_input(
node, edge_program
):
node_to_external_map[node] = len(node_to_external_map)
for node in edge_program.graph_module.graph.nodes:
if is_graph_output(node):
7 changes: 6 additions & 1 deletion backends/qualcomm/builders/op_index_put.py
@@ -1,9 +1,10 @@
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -22,6 +23,10 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
# Because args[0] of the index_put op is not annotated by the quantizer, fill in its quant_attrs from this node's meta here.
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,