Skip to content

Commit 1774308

Browse files
authored
Merge branch 'main' into sqrt_tensor
2 parents 56555a5 + 121714a commit 1774308

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+769
-324
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ It supports a wide range of models including LLMs (Large Language Models), CV (C
1919
Platform Support:
2020
- Operating Systems:
2121
- iOS
22-
- Mac
22+
- MacOS (ARM64)
2323
- Android
2424
- Linux
2525
- Microcontrollers

backends/arm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ if(NOT EXECUTORCH_ROOT)
1212
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
1313
endif()
1414

15+
add_compile_options("-Wall" "-Werror")
16+
1517
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
1618

1719
set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -712,32 +712,14 @@ def _create_requantize_node(
712712
out_dtype: torch.dtype,
713713
graph: torch.fx.Graph,
714714
) -> torch.fx.Node:
715-
in_scale_tensor = graph.call_function(
716-
exir_ops.edge.aten.full.default, args=((1,), in_scale)
717-
)
718-
in_zero_point_tensor = graph.call_function(
719-
exir_ops.edge.aten.full.default,
720-
args=((1,), in_zero_point),
721-
kwargs={"dtype": torch.int32},
722-
)
723-
out_scale_tensor = graph.call_function(
724-
exir_ops.edge.aten.full.default, args=((1,), out_scale)
725-
)
726-
out_zero_point_tensor = graph.call_function(
727-
exir_ops.edge.aten.full.default,
728-
args=((1,), out_zero_point),
729-
kwargs={"dtype": torch.int32},
730-
)
731-
# cadence::requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype) -> Tensor Y
732-
# TODO(hardiksharma): Add support for per-tensor requantize.
733715
return graph.call_function(
734-
exir_ops.edge.cadence.requantize.default,
716+
exir_ops.edge.cadence.requantize.per_tensor,
735717
args=(
736718
in_tensor,
737-
in_scale_tensor,
738-
in_zero_point_tensor,
739-
out_scale_tensor,
740-
out_zero_point_tensor,
719+
in_scale,
720+
in_zero_point,
721+
out_scale,
722+
out_zero_point,
741723
out_dtype,
742724
),
743725
)

backends/cadence/aot/remove_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def call_operator(
447447
kwargs: dict[str, Argument],
448448
meta: NodeMetadata,
449449
) -> ProxyValue:
450-
if op != exir_ops.edge.cadence.requantize.default:
450+
if op != exir_ops.edge.cadence.requantize.per_tensor:
451451
return super().call_operator(op, args, kwargs, meta)
452452

453453
# Parse the args

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_force_quant_dequant_fusion(self) -> None:
306306
# Verify that dequant/quant pair was replaced with requantize.
307307
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
308308
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
309-
exir_ops.edge.cadence.requantize.default: 1,
309+
exir_ops.edge.cadence.requantize.per_tensor: 1,
310310
},
311311
)
312312

@@ -336,7 +336,7 @@ def test_no_replace_quant_permute_dequant_with_requantize(self) -> None:
336336
# quantize -> permute -> dequantize should not be replaced with requantize.
337337
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
338338
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
339-
exir_ops.edge.cadence.requantize.default: 0,
339+
exir_ops.edge.cadence.requantize.per_tensor: 0,
340340
},
341341
)
342342

@@ -364,7 +364,7 @@ def test_replace_quant_view_dequant_with_requantize(self) -> None:
364364
# Verify that dequant/quant pair was replaced with requantize.
365365
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
366366
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
367-
exir_ops.edge.cadence.requantize.default: 1,
367+
exir_ops.edge.cadence.requantize.per_tensor: 1,
368368
},
369369
)
370370

@@ -390,7 +390,7 @@ def test_replace_dequant_quant_with_requantize(self) -> None:
390390
# Verify that dequant -> quant was replaced with requantize.
391391
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
392392
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
393-
exir_ops.edge.cadence.requantize.default: 1,
393+
exir_ops.edge.cadence.requantize.per_tensor: 1,
394394
},
395395
)
396396

@@ -420,7 +420,7 @@ def test_replace_dequant_permute_quant_with_requantize(self) -> None:
420420
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
421421
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
422422
exir_ops.edge.aten.permute_copy.default: 1,
423-
exir_ops.edge.cadence.requantize.default: 1,
423+
exir_ops.edge.cadence.requantize.per_tensor: 1,
424424
},
425425
)
426426

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_advance_branched_quantize(self) -> None:
217217
self.assertEqual(
218218
count_node(
219219
graph_module,
220-
exir_ops.edge.cadence.requantize.default,
220+
exir_ops.edge.cadence.requantize.per_tensor,
221221
),
222222
1,
223223
)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .annotate_quant_attrs import AnnotateQuantAttrs
99
from .annotate_stack import AnnotateStack
1010
from .annotate_unbind import AnnotateUnbind
11+
from .convert_bmm_to_matmul import ConvertBmmToMatmul
1112
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
1213
from .convert_square_to_pow import ConvertSquareToPow
1314
from .decompose_any import DecomposeAny
@@ -35,7 +36,6 @@
3536
from .remove_0d_tensor import Remove0DTensor
3637
from .remove_redundancy import RemoveRedundancy
3738
from .replace_arange_args import ReplaceArangeArgs
38-
from .replace_index_put_input import ReplaceIndexPutInput
3939
from .replace_inf_values import ReplaceInfValues
4040
from .tag_quant_io import TagQuantIO
4141

@@ -45,6 +45,7 @@
4545
AnnotateQuantAttrs,
4646
AnnotateStack,
4747
AnnotateUnbind,
48+
ConvertBmmToMatmul,
4849
ConvertConv1dToConv2d,
4950
ConvertSquareToPow,
5051
DecomposeAny,
@@ -72,7 +73,6 @@
7273
Remove0DTensor,
7374
RemoveRedundancy,
7475
ReplaceArangeArgs,
75-
ReplaceIndexPutInput,
7676
ReplaceInfValues,
7777
TagQuantIO,
7878
]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import operator
7+
from collections import Counter
8+
from typing import List
9+
10+
import torch
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
14+
15+
16+
class ConvertBmmToMatmul(ExportPass):
17+
"""
18+
Replace bmm to matmul, because bmm is eqaul to matmul in QNN.
19+
Handle missing quantization tag for bmm op.
20+
"""
21+
22+
view_copy = exir_ops.edge.aten.view_copy.default
23+
expand_copy = exir_ops.edge.aten.expand_copy.default
24+
clone = exir_ops.edge.aten.clone.default
25+
bmm = exir_ops.edge.aten.bmm.default
26+
matmul = exir_ops.edge.aten.matmul.default
27+
patterns = [
28+
{expand_copy: 2, view_copy: 3, bmm: 1},
29+
{expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
30+
{bmm: 1},
31+
]
32+
33+
def __init__(self):
34+
super(ConvertBmmToMatmul, self).__init__()
35+
36+
def _get_ordered_inputs(
37+
self, inputs: List[torch.fx.Node], output: torch.fx.Node
38+
) -> List[torch.fx.Node]:
39+
bmm_inputs = []
40+
for arg in output.args:
41+
while arg not in inputs:
42+
arg = arg.args[0]
43+
bmm_inputs.append(arg)
44+
return bmm_inputs
45+
46+
def call(self, graph_module: torch.fx.GraphModule):
47+
graph = graph_module.graph
48+
partitions = get_source_partitions(
49+
graph,
50+
[operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
51+
)
52+
for _, src_partitions in partitions.items():
53+
for src_partition in src_partitions:
54+
op_cnt = Counter([n.target for n in src_partition.nodes])
55+
if op_cnt not in self.patterns:
56+
raise AssertionError(
57+
"Found a new pattern that needs to be converted to a linear op"
58+
)
59+
60+
inputs = src_partition.input_nodes
61+
bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]
62+
output = src_partition.output_nodes[0]
63+
# the order of src_partition.inputs is not guaranteed.
64+
lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
65+
with graph_module.graph.inserting_before(output):
66+
# replace bmm with matmul, because bmm is equal to matmul in qnn.
67+
matmul_node = graph.create_node(
68+
"call_function", self.matmul, (lhs, rhs)
69+
)
70+
matmul_node.meta = output.meta
71+
for user in output.users.copy():
72+
user.replace_input_with(output, matmul_node)
73+
74+
graph.eliminate_dead_code()
75+
graph_module.recompile()
76+
return PassResult(graph_module, True)

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
from executorch.backends.qualcomm.builders.node_visitor import q_ops
1111

12-
from executorch.backends.qualcomm.builders.utils import is_parameter
12+
from executorch.backends.qualcomm.builders.utils import (
13+
is_mutable_buffer_input,
14+
is_parameter,
15+
)
1316
from executorch.backends.qualcomm.utils.constants import (
1417
QCOM_ENCODING,
1518
QCOM_QUANT_ATTRS,
@@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
124127
if (
125128
n.op == "placeholder"
126129
and n.meta.get(QCOM_QUANT_ATTRS)
127-
and not is_parameter(n, self.edge_program)
130+
and (
131+
not is_parameter(n, self.edge_program)
132+
or is_mutable_buffer_input(n, self.edge_program)
133+
)
128134
):
129135
self._insert_quant_node(
130136
graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
AnnotateQuantAttrs,
1414
AnnotateStack,
1515
AnnotateUnbind,
16+
ConvertBmmToMatmul,
1617
ConvertConv1dToConv2d,
1718
ConvertSquareToPow,
1819
DecomposeAny,
@@ -40,7 +41,6 @@
4041
Remove0DTensor,
4142
RemoveRedundancy,
4243
ReplaceArangeArgs,
43-
ReplaceIndexPutInput,
4444
ReplaceInfValues,
4545
TagQuantIO,
4646
)
@@ -80,6 +80,7 @@ def get_capture_program_passes():
8080
(AnnotateQuantAttrs, True),
8181
(AnnotateStack, True),
8282
(AnnotateUnbind, True),
83+
(ConvertBmmToMatmul, False),
8384
(ConvertConv1dToConv2d, True),
8485
(DecomposeAny, True),
8586
(DecomposeColIm, True),
@@ -92,7 +93,6 @@ def get_capture_program_passes():
9293
(RecomposeRmsNorm, False),
9394
(Remove0DTensor, True),
9495
(RemoveRedundancy, True),
95-
(ReplaceIndexPutInput, True),
9696
(TagQuantIO, False),
9797
]
9898

@@ -224,4 +224,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
224224
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
225225
self.add_pass(FuseConsecutiveCast())
226226
self.add_pass(FuseConsecutiveTranspose())
227-
return self._transform(exported_program.graph_module)
227+
self._transform(exported_program.graph_module)
228+
# Update inputs_to_buffers and buffers_to_mutate in graph signature for mutable buffer
229+
# Since Q/DQ nodes will be inserted at the I/O, mapping output node names to buffers would otherwise fail
230+
exported_program._graph_signature = _get_updated_graph_signature(
231+
exported_program.graph_signature,
232+
exported_program.graph_module,
233+
)
234+
return exported_program.graph_module

0 commit comments

Comments
 (0)