Commit 19c5aa1

Fixed the CI for meta's llama
1 parent e922f14 commit 19c5aa1

File tree

6 files changed: +90 additions, -4 deletions

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .annotate_stack import AnnotateStack
 from .annotate_unbind import AnnotateUnbind
+from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -44,6 +45,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import operator
+from collections import Counter
+from typing import List
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+class ConvertBmmToMatmul(ExportPass):
+    """
+    Replace bmm with matmul, because bmm is equal to matmul in QNN.
+    Handle the missing quantization tag for the bmm op.
+    """
+
+    view_copy = exir_ops.edge.aten.view_copy.default
+    expand_copy = exir_ops.edge.aten.expand_copy.default
+    clone = exir_ops.edge.aten.clone.default
+    bmm = exir_ops.edge.aten.bmm.default
+    matmul = exir_ops.edge.aten.matmul.default
+    patterns = [
+        {expand_copy: 2, view_copy: 3, bmm: 1},
+        {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
+        {bmm: 1},
+    ]
+
+    def __init__(self):
+        super(ConvertBmmToMatmul, self).__init__()
+
+    def _get_ordered_inputs(
+        self, inputs: List[torch.fx.Node], output: torch.fx.Node
+    ) -> List[torch.fx.Node]:
+        bmm_inputs = []
+        for arg in output.args:
+            while arg not in inputs:
+                arg = arg.args[0]
+            bmm_inputs.append(arg)
+        return bmm_inputs
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        partitions = get_source_partitions(
+            graph,
+            [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
+        )
+        for _, src_partitions in partitions.items():
+            for src_partition in src_partitions:
+                op_cnt = Counter([n.target for n in src_partition.nodes])
+                if op_cnt not in self.patterns:
+                    raise AssertionError(
+                        "Found a new pattern that needs to be converted to a matmul op"
+                    )
+
+                inputs = src_partition.input_nodes
+                bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]
+                output = src_partition.output_nodes[0]
+                # The order of src_partition.input_nodes is not guaranteed.
+                lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
+                with graph_module.graph.inserting_before(output):
+                    # Replace bmm with matmul, because bmm is equal to matmul in QNN.
+                    matmul_node = graph.create_node(
+                        "call_function", self.matmul, (lhs, rhs)
+                    )
+                    matmul_node.meta = output.meta
+                    for user in output.users.copy():
+                        user.replace_input_with(output, matmul_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)

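Two small facts the pass relies on can be sanity-checked in isolation: torch.matmul computes the same result as torch.bmm on batched 3-D inputs (the equivalence QNN exploits), and collections.Counter, being a dict subclass, compares equal to a plain dict with the same counts, which is what lets the partition's op histogram be matched against the dict literals in patterns. A minimal illustrative check, not part of the commit:

import torch
from collections import Counter

# bmm and matmul agree on 3-D (batched) inputs.
a, b = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
assert torch.allclose(torch.bmm(a, b), torch.matmul(a, b))

# A Counter of partition op targets compares equal to a plain dict pattern.
assert Counter(["expand", "expand", "view", "view", "view", "bmm"]) == {
    "expand": 2,
    "view": 3,
    "bmm": 1,
}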
backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -79,6 +80,7 @@ def get_capture_program_passes():
         (AnnotateQuantAttrs, True),
         (AnnotateStack, True),
         (AnnotateUnbind, True),
+        (ConvertBmmToMatmul, False),
         (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),

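Each entry pairs a pass with a default activation flag, so ConvertBmmToMatmul is registered but stays dormant until a caller opts in (as export_llama_lib does below). A simplified sketch of that shape; the key name and build_passes_job helper are assumptions for illustration, not the real get_capture_program_passes:

from collections import OrderedDict

QCOM_PASS_ACTIVATE_KEY = "activate"  # assumed key name, for illustration only

def build_passes_job(defaults):
    # Map each pass to a settings dict carrying its activation flag.
    return OrderedDict((p, {QCOM_PASS_ACTIVATE_KEY: on}) for p, on in defaults)

passes_job = build_passes_job([("AnnotateQuantAttrs", True), ("ConvertBmmToMatmul", False)])
passes_job["ConvertBmmToMatmul"][QCOM_PASS_ACTIVATE_KEY] = True  # opt in explicitly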
backends/qualcomm/_passes/utils.py

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program():
         AnnotateQuantAttrs,
         AnnotateStack,
         AnnotateUnbind,
+        ConvertBmmToMatmul,
         ConvertConv1dToConv2d,
         DecomposeAny,
         DecomposeColIm,
@@ -82,11 +83,13 @@ def get_passes_dependency_for_capture_program():
     return {
         AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
         AnnotateQuantAttrs: [
+            ConvertBmmToMatmul,
             RecomposePixelUnshuffle,
             RemoveRedundancy,
         ],
         AnnotateStack: [RemoveRedundancy],
         AnnotateUnbind: [RemoveRedundancy],
+        ConvertBmmToMatmul: [RecomposePixelUnshuffle],
         DecomposeAny: [RemoveRedundancy],
         DecomposeColIm: [FoldQDQ],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],

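The table maps each pass to its prerequisites: ConvertBmmToMatmul must run after RecomposePixelUnshuffle, and AnnotateQuantAttrs after ConvertBmmToMatmul, so the matmul node already exists when quant attributes are annotated. A hedged sketch (not the real scheduler, and with abridged illustrative entries) of how such a table yields a pass ordering:

from graphlib import TopologicalSorter  # Python 3.9+

# Values are predecessors, so each pass's dependencies are scheduled first.
dep_table = {
    "AnnotateQuantAttrs": ["ConvertBmmToMatmul", "RecomposePixelUnshuffle", "RemoveRedundancy"],
    "ConvertBmmToMatmul": ["RecomposePixelUnshuffle"],
    "RecomposePixelUnshuffle": ["RemoveRedundancy"],
    "RemoveRedundancy": [],
}
print(list(TopologicalSorter(dep_table).static_order()))
# e.g. ['RemoveRedundancy', 'RecomposePixelUnshuffle', 'ConvertBmmToMatmul', 'AnnotateQuantAttrs']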
backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 5 additions & 4 deletions
@@ -292,14 +292,15 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
     )
 
 def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
-    input = node.args[0]
+    # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process.
     value = node.args[2]
+
     input_qspec_map = {}
-    input_qspec_map[input] = quantization_config.input_activation
-    input_qspec_map[value] = SharedQuantizationSpec((input, node))
+    input_qspec_map[value] = quantization_config.input_activation
+
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
-        output_qspec=SharedQuantizationSpec((input, node)),
+        output_qspec=SharedQuantizationSpec((value, node)),
         _annotated=True,
     )
 

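For context, an exported aten.index_put node carries args of (input, indices, values): args[0] is the mutable buffer that convert_pt2e folds (hence it is no longer annotated), and args[2] is the values tensor that now receives the activation qspec and anchors the shared output spec. A small eager-mode illustration of that argument layout; the shapes here are made up:

import torch

# index_put_(indices, values): self is the mutable buffer (e.g. a KV cache),
# indices pick the slot, values supply the new data.
cache = torch.zeros(4, 8)
pos = torch.tensor([1])
new_row = torch.ones(1, 8)
cache.index_put_((pos,), new_row)  # equivalent to cache[pos] = new_row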
examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -914,6 +914,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
     from executorch.backends.qualcomm._passes import (
         AnnotateStack,
+        ConvertBmmToMatmul,
         FoldQDQ,
         RecomposeRmsNorm,
         TagQuantIO,
@@ -956,6 +957,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     passes_job = get_capture_program_passes()
     dep_table = get_passes_dependency_for_capture_program()
     passes_job[AnnotateStack][QCOM_PASS_ACTIVATE_KEY] = True
+    passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True
     passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True
     passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
     passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][

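To see the detection step the newly activated pass performs, here is a hedged standalone sketch (the toy module and shapes are assumptions) of how get_source_partitions locates a bmm in an exported graph before ConvertBmmToMatmul swaps in matmul:

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

class Attn(torch.nn.Module):
    def forward(self, q, k):
        # A bmm of the kind the llama attention path produces.
        return torch.bmm(q, k.transpose(1, 2))

ep = torch.export.export(Attn(), (torch.randn(1, 4, 8), torch.randn(1, 4, 8)))
parts = get_source_partitions(ep.module().graph, [torch.bmm])
print(parts)  # expected: one partition wrapping the aten.bmm node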