
Commit 2c88338

Qualcomm AI Engine Direct - Support MaskedSoftmax in static llama
Summary:
- Add a unit test for masked softmax.
- Add amin op support.
- Add a flag `--enable_masked_softmax` to enable the MaskedSoftmax feature, which is designed to improve LLM accuracy and performance on the HTP backend. During backend optimization, MaskedSoftmax replaces the Softmax(Add(In, Mask)) structure in the attention blocks of LLMs. For more details, please refer to the QNN documentation. Note that this feature is only supported starting from QNN 2.35.
1 parent b562f36 commit 2c88338
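
For context, the targeted subgraph looks like the minimal sketch below (an illustration only; the shapes are assumed and the -65535.0 mask value mirrors the new unit test): attention scores receive a large negative additive mask before the softmax, and it is this Softmax(Add(In, Mask)) pattern that the backend optimization can fuse into a single MaskedSoftmax op when the flag is enabled.

```python
import torch

# Illustrative shapes only: batch=1, heads=8, query_len=1, kv_len=128.
scores = torch.randn(1, 8, 1, 128)            # QK^T / sqrt(head_dim)
mask = torch.full((1, 1, 1, 128), -65535.0)   # additive attention mask
mask[..., :1] = 0.0                           # visible positions carry 0

# The Softmax(Add(In, Mask)) structure in the attention block; with
# --enable_masked_softmax the HTP backend replaces this subgraph with a
# MaskedSoftmax op (QNN >= 2.35).
attn_weights = torch.softmax(scores + mask, dim=-1)
print(attn_weights.shape)  # torch.Size([1, 8, 1, 128])
```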

11 files changed, +232 −7 lines


backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ class LayoutTransform(ExportPass):
     exir_ops.edge.aten.abs.default,
     exir_ops.edge.aten.add.Tensor,
     exir_ops.edge.aten.amax.default,
+    exir_ops.edge.aten.amin.default,
     exir_ops.edge.aten.atan.default,
     exir_ops.edge.aten.bitwise_or.Tensor,
     exir_ops.edge.aten.bmm.default,

backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
  */
 #include <executorch/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h>
 #include <pybind11/pybind11.h>
+#include "QnnSdkBuildId.h"

 namespace py = pybind11;
 namespace executorch {
@@ -19,6 +20,7 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) {
   // TODO: Add related documents for configurations listed below
   using namespace qnn_delegate;

+  m.def("GetQnnSdkBuildId", []() { return std::string(QNN_SDK_BUILD_ID); });
   py::class_<QnnExecuTorchContextBinary>(m, "QnnExecuTorchContextBinary")
       .def(py::init<>());
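
This binding exposes the QNN SDK build id to Python. A minimal sketch of how it is consumed (mirroring the call made in `backends/qualcomm/utils/utils.py` further below; the exact build-id format is SDK-dependent and not specified by this commit):

```python
# Sketch only: assumes the PyQnnManagerAdaptor extension module is importable,
# as it is in backends/qualcomm/utils/utils.py.
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor

build_id = PyQnnManagerAdaptor.GetQnnSdkBuildId()
# The version helper below expects the string to contain a "v<major>.<minor>" token.
print(build_id)
```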

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_amax,
+    op_amin,
     op_and,
     op_arange,
     op_argmax,
@@ -106,6 +107,7 @@
     op_adaptive_avg_pool2d,
     op_add,
     op_amax,
+    op_amin,
     op_and,
     op_arange,
     op_argmax,

backends/qualcomm/builders/op_amin.py (new file)

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpReduceMin, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AMin(NodeVisitor):
+    target = ["aten.amin.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        # reduction dims and keep dims
+        mean_dims = cast(List[int], node.args[1])
+        mean_dims = [
+            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
+        ]
+        if QCOM_AXIS_ORDER in node.meta:
+            mean_dims = [
+                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+            ]
+        mean_dims_shape = [len(mean_dims)]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        reduce_min_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReduceMin.op_name,
+        )
+        reduce_min_op.AddInputTensors([input_tensor_wrapper])
+        reduce_min_op.AddOutputTensors([output_tensor_wrapper])
+        reduce_min_op.AddTensorParam(
+            OpReduceMin.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(mean_dims_shape),
+            mean_dims_shape,
+            np.array(mean_dims, dtype=np.uint32),
+            True,
+        )
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            reduce_min_op.AddScalarParam(
+                OpReduceMin.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return reduce_min_op
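
As a usage note (an illustration, not part of the commit): this visitor lowers the eager `torch.amin` reduction. `node.args[1]` holds the reduction dims, which are normalized and become the QNN ReduceMin `axes` tensor parameter, and the optional `node.args[2]` becomes the `keep_dims` scalar parameter.

```python
import torch

x = torch.randn(1, 8, 16, 128)

# aten.amin.default with dims=[3] and keepdim=True; after lowering, dims map to
# the ReduceMin "axes" param (re-indexed under QCOM_AXIS_ORDER when present)
# and keepdim maps to its "keep_dims" param.
y = torch.amin(x, dim=[3], keepdim=True)
print(y.shape)  # torch.Size([1, 8, 16, 1])
```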

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions
@@ -218,6 +218,11 @@ def annotate_argmax(node: Node, quantization_config: QuantizationConfig) -> None
     annotate_single_in(node, quantization_config)


+@register_annotator([torch.ops.aten.amin.default])
+def annotate_amin(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.argmin.default])
 def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in(node, quantization_config)

backends/qualcomm/tests/models.py

Lines changed: 21 additions & 0 deletions
@@ -102,6 +102,16 @@ def forward(self, x):
         return torch.amax(x, dim=self.dim, keepdim=self.keepdim)


+class AMin(torch.nn.Module):
+    def __init__(self, dim=None, keepdim=False):
+        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.amin(x, dim=self.dim, keepdim=self.keepdim)
+
+
 class Arange(torch.nn.Module):
     def __init__(self, start, end, step, dtype):
         super().__init__()
@@ -1155,6 +1165,17 @@ def forward(self, attn_mask):
         )


+class MaskedSoftmax(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, attention_mask, input):
+        attn_weights = torch.where(
+            attention_mask == 0, input, torch.amin(input, dim=3, keepdim=True) + (-20)
+        )
+        return torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+
+
 class MaxDim(torch.nn.Module):
     def __init__(self):
         super().__init__()
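
For intuition, the small standalone check below (an illustration, not part of the commit) shows why the `MaskedSoftmax` module above approximates the usual additive-mask pattern: masked positions are pushed to the row minimum minus 20, so they receive effectively zero probability after the softmax, matching `softmax(input + mask)` for a large negative mask such as the -65535.0 used in the unit test.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 1, 1, 128)

# Additive causal mask for a single query token: only position 0 is visible.
causal = (torch.arange(128) <= 0).view(1, 1, 1, 128)
additive_mask = torch.full((1, 1, 1, 128), -65535.0).masked_fill(causal, 0.0)

# Conventional Softmax(Add(In, Mask)) formulation.
reference = torch.softmax(logits + additive_mask, dim=-1)

# Formulation used by the MaskedSoftmax test module: no Add on the logits.
attn_weights = torch.where(
    additive_mask == 0, logits, torch.amin(logits, dim=3, keepdim=True) - 20
)
candidate = torch.softmax(attn_weights, dim=-1)

print(torch.allclose(reference, candidate, atol=1e-6))  # True
```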

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 68 additions & 4 deletions
@@ -46,17 +46,14 @@
     from_context_binary,
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
+    is_qnn_sdk_version_less_than,
     PyQnnManagerAdaptor,
     rewrite_prepared_observer,
     skip_annotation,
     to_edge_transform_and_lower_to_qnn,
     update_spill_fill_size,
 )

-from executorch.examples.models.llama.llama_transformer import MOEFeedForward
-
-from executorch.examples.models.llama.model_args import ModelArgs
-
 from executorch.examples.qualcomm.utils import (
     make_quantizer,
     setup_common_args_and_variables,
@@ -136,6 +133,13 @@ def test_qnn_backend_amax(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_amin(self):
+        modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -1227,6 +1231,9 @@ def test_qnn_backend_lift_add_tensor(self):

     @unittest.skip("Fail because of bad accuracy")
     def test_qnn_backend_moe_feed_forward(self):
+        from executorch.examples.models.llama.llama_transformer import MOEFeedForward
+        from executorch.examples.models.llama.model_args import ModelArgs
+
         args = ModelArgs()
         args.dim = 32
         args.n_heads = 8
@@ -1421,6 +1428,14 @@ def test_qnn_backend_amax(self):
             module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_amin(self):
+        modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -2643,8 +2658,57 @@ def test_qnn_backend_einsum_outer_product_relu(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    @unittest.skipIf(is_qnn_sdk_version_less_than("2.35"), "UT pass after QNN 2.35")
+    def test_qnn_backend_masked_softmax(self):
+        if self.enable_x86_64:
+            self.skipTest(
+                "At the moment, testing is only being conducted on the device."
+            )
+        module = MaskedSoftmax()  # noqa: F405
+        kv_arange = torch.arange(128)
+        reshaped_cache_position = torch.tensor([[0]])
+
+        # Simplest and most efficient way to obtain a causal mask
+        causal_mask = kv_arange <= reshaped_cache_position
+        atten_mask = torch.full((1, 128), torch.tensor(-65535.0))
+        atten_mask = atten_mask.masked_fill(causal_mask, 0)
+        atten_mask = atten_mask[None, None, :, :].expand(1, -1, -1, -1)
+        sample_input = (atten_mask, torch.randn([1, 1, 1, 128]))
+        # Masked softmax is only supported in quantized models
+        module = self.get_qdq_module(
+            module, sample_input, quant_dtype=QuantDtype.use_16a8w
+        )
+        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        compiler_spec = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+            optrace=True,
+        )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
+                module, sample_input, compiler_spec
+            ).to_executorch()
+            pte_path = f"{tmp_dir}/model.pte"
+            with open(pte_path, "wb") as f:
+                edge_prog_mgr.write_to_file(f)
+            adb = self.get_adb_tool(pte_path)
+            binaries_trace = generate_optrace(
+                tmp_dir, self.chipset_table[self.model], adb, pte_path, sample_input
+            )
+            has_masked_softmax = False
+            for _, (_, qhas) in binaries_trace.items():
+                with open(qhas, "r") as qhas_file:
+                    qhas_data = json.load(qhas_file)
+                for row in qhas_data["data"]["htp_op_types"]["data"]:
+                    if "MaskedSoftmax" in row["op"]:
+                        has_masked_softmax = True
+        self.assertTrue(has_masked_softmax)
+
     @unittest.skip("UT pass before QNN 2.26, segfault during partitioner")
     def test_qnn_backend_moe_feed_forward(self):
+        from executorch.examples.models.llama.llama_transformer import MOEFeedForward
+        from executorch.examples.models.llama.model_args import ModelArgs
+
         args = ModelArgs()
         args.dim = 32
         args.n_heads = 8

backends/qualcomm/utils/utils.py

Lines changed: 17 additions & 1 deletion
@@ -4,14 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import operator
+import re
 import warnings
 from collections import defaultdict, OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor

 import executorch.exir as exir
-
 import torch

 from executorch.backends.qualcomm._passes import AnnotateStack, AnnotateUnbind
@@ -1167,3 +1167,19 @@ def rewrite_prepared_observer(
             continue
         for target_name in module_name_list[old_module]:
             setattr(graph_module, target_name, new_observer)
+
+
+def is_qnn_sdk_version_less_than(target_version):
+    current_version = PyQnnManagerAdaptor.GetQnnSdkBuildId()
+
+    match = re.search(r"v(\d+)\.(\d+)", current_version)
+    if match:
+        current_major, current_minor = map(int, match.groups()[:2])
+    else:
+        raise ValueError(
+            f"Failed to get current major and minor version from QNN sdk Build id {current_version}"
+        )
+
+    target_major, target_minor = map(int, target_version.split(".")[:2])
+
+    return current_major == target_major and current_minor < target_minor
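
A small standalone illustration (not part of the commit; the sample build-id strings are made up) of how this check behaves, assuming the QNN SDK build id embeds a `v<major>.<minor>` token as the regex expects:

```python
import re

def is_less_than(build_id, target_version):
    # Same parsing and comparison logic as is_qnn_sdk_version_less_than,
    # with the build id passed in directly for illustration.
    major, minor = map(int, re.search(r"v(\d+)\.(\d+)", build_id).groups())
    target_major, target_minor = map(int, target_version.split(".")[:2])
    return major == target_major and minor < target_minor

print(is_less_than("v2.34.0.123456", "2.35"))  # True  -> masked softmax gets disabled
print(is_less_than("v2.35.0.123456", "2.35"))  # False -> masked softmax is allowed
```

Note that the comparison only reports "less than" when the major versions match, so a hypothetical newer major release would never be treated as older than the target.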

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 9 additions & 0 deletions
@@ -123,12 +123,16 @@ On the other hand, if you already have a pre-compiled .pte model, you can perfor
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
 ```

+#### KV Cache Updater
+
 You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask".
 `KV_UPDATER` = "shift_pointer"
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```

+#### Lookahead Decoding Mode
+
 You can choose the lookahead mode to enhance decoding speed. To use this mode, you need to specify the following parameters:
 - `--ngram` (N-gram size): Represents the size of the n-grams used in the lookahead process.
 - `--window` (window size): Determines how many future tokens the algorithm attempts to predict in each step.
@@ -139,3 +143,8 @@ For more details, please refer to the paper ["Break the Sequential Dependency of
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
 ```
+
+#### Masked Softmax
+
+You can enable the MaskedSoftmax feature by passing the flag `--enable_masked_softmax`. It is designed to improve LLM accuracy and performance on the HTP backend. During backend optimization, MaskedSoftmax replaces the Softmax(Add(In, Mask)) structure in the attention blocks of LLMs. For more details, please refer to the QNN documentation.
+Note that it is only supported starting from QNN 2.35.
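
Following the pattern of the commands above, the flag is simply appended to an export/run invocation; the example below reuses the same placeholders as the earlier hybrid-mode command, with only `--enable_masked_softmax` added:

```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --enable_masked_softmax
```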

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 14 additions & 1 deletion
@@ -52,6 +52,7 @@
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
     get_soc_to_chipset_map,
+    is_qnn_sdk_version_less_than,
     to_edge_transform_and_lower_to_qnn,
     update_spill_fill_size,
 )
@@ -515,9 +516,9 @@ def compile(args, pte_filename, tokenizer):
     kv_config.max_batch_size = 1
     kv_config.max_seq_len = args.max_seq_len
     kv_config.use_kv_cache = True
+    kv_config.enable_masked_softmax = args.enable_masked_softmax

     prefill_config = copy.copy(kv_config)
-    prefill_config.max_seq_len = args.max_seq_len
     prefill_config.use_kv_cache = (
         False if args.max_seq_len == args.prefill_ar_len else True
     )
@@ -1144,6 +1145,12 @@ def _build_parser():
         action="store_true",
     )

+    parser.add_argument(
+        "--enable_masked_softmax",
+        help="The MaskedSoftmax feature is designed to optimize the LLMs accuracy and performance executed on HTP backend. Note that it is only supported starting from QNN 2.35.",
+        action="store_true",
+    )
+
     parser.add_argument("-v", "--verbose", action="store_true")

     return parser
@@ -1197,6 +1204,12 @@ def export_llama(args) -> None:
     else:
         raise RuntimeError(f"Using an unknown kv update {args.kv_updater}")

+    if args.enable_masked_softmax and is_qnn_sdk_version_less_than("2.35"):
+        logging.warning(
+            "Masked softmax is supported after QNN SDK 2.35. Override enable_masked_softmax to False."
+        )
+        args.enable_masked_softmax = False
+
     if args.pre_gen_pte:
         inference(args, pte_filename, runtime_tokenizer_path, args.pre_gen_pte)
         print(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
