Skip to content

Commit 8d3710e

Browse files
committed
Qualcomm AI Engine Direct - Add MHA2SHA pass
Summary:
- Integrated the MHA-to-SHA pass and invoked it in qnn_preprocess.
- Refactored MHA in static llama.
- Added support for masked softmax.
- Included SpinQuant R3 support.
- Combined the n_heads key/value caches into a single cache per layer to reduce the number of graph inputs and outputs, which improves performance.
- Deprecated the ShiftPointer KV-updater mode. Since each layer now has its own KV cache, the V cache no longer benefits from ShiftPointer (which previously avoided copying the new V cache into the input V cache). To prevent user confusion, ShiftPointer mode has been deprecated.
- Applied the correct input template for SmolLM2 135M.
- Corrected the quantization annotation for reshape.
- Removed outdated code from CanonicalizeConv.
1 parent ca4c575 commit 8d3710e

30 files changed

+1000
-1025
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
 from .canonicalize_conv import CanonicalizeConv
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
+from .convert_mha_to_sha import ConvertMhaToSha
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
 from .decompose_binary_alpha import DecomposeBinaryAlpha
@@ -55,6 +56,7 @@
     CanonicalizeConv,
     ConvertBmmToMatmul,
     ConvertLinearToConv2d,
+    ConvertMhaToSha,
     ConvertSquareToPow,
     DecomposeAny,
     DecomposeBinaryAlpha,

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
 import torch

 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
-from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._guards import detect_fake_mode

@@ -197,14 +196,6 @@ def call(self, graph_module: torch.fx.GraphModule):
             )
             squeeze_node.meta = copy_meta(node.meta)

-            if QCOM_REQUANTIZE in input_node.meta:
-                input_node.meta.pop(QCOM_REQUANTIZE)
-                if QCOM_REQUANTIZE in node.meta:
-                    squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[
-                        QCOM_REQUANTIZE
-                    ]
-                conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
-
             for user in node.users.copy():
                 user.replace_input_with(node, squeeze_node)

0 commit comments

Comments
 (0)