Merged
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -63,6 +63,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
exir_ops.edge.aten.atan.default,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bmm.default,
18 changes: 18 additions & 0 deletions backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp
@@ -7,6 +7,7 @@
*/
#include <executorch/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h>
#include <pybind11/pybind11.h>
#include "QnnSdkBuildId.h"

namespace py = pybind11;
namespace executorch {
@@ -15,10 +16,27 @@ namespace qnn {

using executorch::runtime::Error;

std::string GetQnnSdkBuildId(std::string library_path) {
QnnImplementation qnn_loaded_backend = QnnImplementation(library_path);
ET_CHECK_MSG(
qnn_loaded_backend.Load(nullptr) == Error::Ok,
"Fail to load Qnn library");
const char* id = nullptr;
// Safe to call any time, backend does not have to be created.
Qnn_ErrorHandle_t err =
qnn_loaded_backend.GetQnnInterface().qnn_backend_get_build_id(&id);
if (err != QNN_SUCCESS || id == nullptr) {
throw std::runtime_error("Failed to get QNN backend build ID");
}
qnn_loaded_backend.TerminateAllBackends();
return std::string(id);
}

PYBIND11_MODULE(PyQnnManagerAdaptor, m) {
// TODO: Add related documents for configurations listed below
using namespace qnn_delegate;

m.def("GetQnnSdkBuildId", &GetQnnSdkBuildId);
py::class_<QnnExecuTorchContextBinary>(m, "QnnExecuTorchContextBinary")
.def(py::init<>());

2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -10,6 +10,7 @@
op_adaptive_avg_pool2d,
op_add,
op_amax,
op_amin,
op_and,
op_arange,
op_argmax,
@@ -106,6 +107,7 @@
op_adaptive_avg_pool2d,
op_add,
op_amax,
op_amin,
op_and,
op_arange,
op_argmax,
85 changes: 85 additions & 0 deletions backends/qualcomm/builders/op_amin.py
@@ -0,0 +1,85 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpReduceMin, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class AMin(NodeVisitor):
target = ["aten.amin.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

# reduction dims and keep dims
mean_dims = cast(List[int], node.args[1])
mean_dims = [
mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
]
if QCOM_AXIS_ORDER in node.meta:
mean_dims = [
node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
]
mean_dims_shape = [len(mean_dims)]

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

reduce_min_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpReduceMin.op_name,
)
reduce_min_op.AddInputTensors([input_tensor_wrapper])
reduce_min_op.AddOutputTensors([output_tensor_wrapper])
reduce_min_op.AddTensorParam(
OpReduceMin.param_axes,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(mean_dims_shape),
mean_dims_shape,
np.array(mean_dims, dtype=np.uint32),
True,
)
if len(node.args) > 2:
keep_dims = cast(bool, node.args[2])
reduce_min_op.AddScalarParam(
OpReduceMin.param_keep_dims,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: keep_dims},
)

return reduce_min_op
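
As a reference point (not part of the diff), the lowered ReduceMin op must match `torch.amin` semantics, including negative-dim normalization and `keepdim` handling that the builder above performs; a minimal sketch:

```python
import torch

x = torch.randn(4, 4)

# keepdim=True retains the reduced axis with size 1.
assert torch.amin(x, dim=1, keepdim=True).shape == (4, 1)

# Negative dims normalize as dim % rank, matching the builder's handling above.
assert torch.equal(torch.amin(x, dim=-1), torch.amin(x, dim=1))
```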
5 changes: 5 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
@@ -217,6 +217,11 @@ def annotate_argmax(node: Node, quantization_config: QuantizationConfig) -> None
annotate_single_in(node, quantization_config)


@register_annotator([torch.ops.aten.amin.default])
def annotate_amin(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.argmin.default])
def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in(node, quantization_config)
@@ -32,6 +32,7 @@ class QnnInterface {

// --------- QnnBackend ---------
DEFINE_SHIM_FUNCTION_INTERFACE(backend_create, backendCreate);
DEFINE_SHIM_FUNCTION_INTERFACE(backend_get_build_id, backendGetBuildId);
DEFINE_SHIM_FUNCTION_INTERFACE(backend_free, backendFree);
DEFINE_SHIM_FUNCTION_INTERFACE(
backend_register_op_package,
21 changes: 21 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -102,6 +102,16 @@ def forward(self, x):
return torch.amax(x, dim=self.dim, keepdim=self.keepdim)


class AMin(torch.nn.Module):
def __init__(self, dim=None, keepdim=False):
super().__init__()
self.dim = dim
self.keepdim = keepdim

def forward(self, x):
return torch.amin(x, dim=self.dim, keepdim=self.keepdim)


class Arange(torch.nn.Module):
def __init__(self, start, end, step, dtype):
super().__init__()
@@ -1155,6 +1165,17 @@ def forward(self, attn_mask):
)


class MaskedSoftmax(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, attention_mask, input):
attn_weights = torch.where(
attention_mask == 0, input, torch.amin(input, dim=3, keepdim=True) + (-20)
)
return torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)


class MaxDim(torch.nn.Module):
def __init__(self):
super().__init__()
72 changes: 68 additions & 4 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -46,17 +46,14 @@
from_context_binary,
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
is_qnn_sdk_version_less_than,
PyQnnManagerAdaptor,
rewrite_prepared_observer,
skip_annotation,
to_edge_transform_and_lower_to_qnn,
update_spill_fill_size,
)

from executorch.examples.models.llama.llama_transformer import MOEFeedForward

from executorch.examples.models.llama.model_args import ModelArgs

from executorch.examples.qualcomm.utils import (
make_quantizer,
setup_common_args_and_variables,
@@ -136,6 +133,13 @@ def test_qnn_backend_amax(self):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_amin(self):
modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(4, 4),)
for i, module in enumerate(modules):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_any(self):
modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -1227,6 +1231,9 @@ def test_qnn_backend_lift_add_tensor(self):

@unittest.skip("Fail because of bad accuracy")
def test_qnn_backend_moe_feed_forward(self):
from executorch.examples.models.llama.llama_transformer import MOEFeedForward
from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs()
args.dim = 32
args.n_heads = 8
@@ -1421,6 +1428,14 @@ def test_qnn_backend_amax(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_amin(self):
modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(4, 4),)
for i, module in enumerate(modules):
with self.subTest(i=i):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_any(self):
modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -2643,8 +2658,57 @@ def test_qnn_backend_einsum_outer_product_relu(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

@unittest.skipIf(is_qnn_sdk_version_less_than("2.35"), "UT pass after QNN 2.35")
def test_qnn_backend_masked_softmax(self):
if self.enable_x86_64:
self.skipTest(
"At the moment, testing is only being conducted on the device."
)
module = MaskedSoftmax() # noqa: F405
kv_arange = torch.arange(128)
reshaped_cache_position = torch.tensor([[0]])

# Simplest and most efficient way to obtain a causal mask
causal_mask = kv_arange <= reshaped_cache_position
atten_mask = torch.full((1, 128), torch.tensor(-65535.0))
atten_mask = atten_mask.masked_fill(causal_mask, 0)
atten_mask = atten_mask[None, None, :, :].expand(1, -1, -1, -1)
sample_input = (atten_mask, torch.randn([1, 1, 1, 128]))
# Masked softmax is only supported in quantized models
module = self.get_qdq_module(
module, sample_input, quant_dtype=QuantDtype.use_16a8w
)
backend_options = generate_htp_compiler_spec(use_fp16=False)
compiler_spec = generate_qnn_executorch_compiler_spec(
soc_model=self.chipset_table[TestQNN.model],
backend_options=backend_options,
optrace=True,
)
with tempfile.TemporaryDirectory() as tmp_dir:
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
module, sample_input, compiler_spec
).to_executorch()
pte_path = f"{tmp_dir}/model.pte"
with open(pte_path, "wb") as f:
edge_prog_mgr.write_to_file(f)
adb = self.get_adb_tool(pte_path)
binaries_trace = generate_optrace(
tmp_dir, self.chipset_table[self.model], adb, pte_path, sample_input
)
has_masked_softmax = False
for _, (_, qhas) in binaries_trace.items():
with open(qhas, "r") as qhas_file:
qhas_data = json.load(qhas_file)
for row in qhas_data["data"]["htp_op_types"]["data"]:
if "MaskedSoftmax" in row["op"]:
has_masked_softmax = True
self.assertTrue(has_masked_softmax)

@unittest.skip("UT pass before QNN 2.26, segfault during partitioner")
def test_qnn_backend_moe_feed_forward(self):
from executorch.examples.models.llama.llama_transformer import MOEFeedForward
from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs()
args.dim = 32
args.n_heads = 8
28 changes: 27 additions & 1 deletion backends/qualcomm/utils/utils.py
@@ -4,14 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
import os
import re
import warnings
from collections import defaultdict, OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor

import executorch.exir as exir

import torch

from executorch.backends.qualcomm._passes import AnnotateStack, AnnotateUnbind
@@ -1167,3 +1168,28 @@ def rewrite_prepared_observer(
continue
for target_name in module_name_list[old_module]:
setattr(graph_module, target_name, new_observer)


def get_sdk_build_id():
htp_library_path = (
os.environ.get("QNN_SDK_ROOT", None) + "/lib/x86_64-linux-clang/libQnnHtp.so"
)
# The GetQnnSdkBuildId API can be used without needing to create a backend first, so it works regardless of which backend is used.
sdk_build_id = PyQnnManagerAdaptor.GetQnnSdkBuildId(htp_library_path)
return sdk_build_id


def is_qnn_sdk_version_less_than(target_version):
current_version = get_sdk_build_id()

match = re.search(r"v(\d+)\.(\d+)", current_version)
if match:
current_major, current_minor = map(int, match.groups()[:2])
else:
raise ValueError(
f"Failed to get current major and minor version from QNN sdk Build id {current_version}"
)

target_major, target_minor = map(int, target_version.split(".")[:2])

return current_major == target_major and current_minor < target_minor
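
For readers of this diff, here is a standalone sketch (not part of the change itself) of how the version check above behaves; the build-ID strings are hypothetical examples of the `v<major>.<minor>` pattern the regex expects:

```python
import re

def is_less_than(build_id: str, target_version: str) -> bool:
    # Mirrors is_qnn_sdk_version_less_than: compare minor versions only
    # when the major versions match.
    match = re.search(r"v(\d+)\.(\d+)", build_id)
    if match is None:
        raise ValueError(f"Cannot parse version from build id: {build_id}")
    current_major, current_minor = map(int, match.groups()[:2])
    target_major, target_minor = map(int, target_version.split(".")[:2])
    return current_major == target_major and current_minor < target_minor

# Hypothetical build-ID strings for illustration.
print(is_less_than("QNN v2.34.0.250300", "2.35"))  # True  -> guarded test is skipped
print(is_less_than("QNN v2.35.0.250300", "2.35"))  # False -> guarded test runs
print(is_less_than("QNN v3.1.0.000000", "2.35"))   # False -> different major version
```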
9 changes: 9 additions & 0 deletions examples/qualcomm/oss_scripts/llama/README.md
@@ -124,12 +124,16 @@ On the other hand, if you already have a pre-compiled .pte model, you can perfor
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
```

#### KV Cache Updater

You can select the KV cache update mechanism at runtime by setting the `KV_UPDATER` variable to either `"shift_pointer"` or `"smart_mask"`. By default, it is set to `"smart_mask"`. For example, with `KV_UPDATER="shift_pointer"`:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
```

#### Lookahead Decoding Mode

You can enable lookahead decoding mode to speed up decoding. To use this mode, you need to specify the following parameters:
- `--ngram` (N-gram size): Represents the size of the n-grams used in the lookahead process.
- `--window` (window size): Determines how many future tokens the algorithm attempts to predict in each step.
@@ -140,3 +144,8 @@ For more details, please refer to the paper ["Break the Sequential Dependency of
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
```

#### Masked Softmax

You can enable the MaskedSoftmax feature by providing the flag `--enable_masked_softmax`. It is designed to improve the accuracy and performance of LLMs executed on the HTP backend. During backend optimization, MaskedSoftmax replaces the Softmax(Add(In, Mask)) structure found in the attention blocks of LLMs. For more details, please refer to the QNN documentation.
Note that MaskedSoftmax is only supported starting from QNN 2.35.
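
To illustrate the pattern being replaced (a sketch only, not the backend's internal rewrite), conventional attention adds a large negative bias to masked positions before the softmax; MaskedSoftmax fuses this masking step. The tensor shapes and the -65535.0 fill value below mirror the `MaskedSoftmax` unit test in this PR and are otherwise illustrative:

```python
import torch
import torch.nn.functional as F

# Conventional pattern: Softmax(Add(scores, mask)). Masked positions carry a
# large negative additive bias, so they receive ~zero probability.
scores = torch.randn(1, 1, 1, 128)
mask = torch.zeros(1, 1, 1, 128)
mask[..., 64:] = -65535.0  # hypothetical: mask out the second half of the row

probs = F.softmax(scores + mask, dim=-1)

# Masked positions collapse to (numerically) zero probability; this is the
# structure the MaskedSoftmax backend optimization targets.
assert probs[..., 64:].max() < 1e-6
```

The `MaskedSoftmax` module added to `backends/qualcomm/tests/models.py` in this change expresses the same masking with `torch.where` and `torch.amin`, which is presumably why `aten.amin` support is added in the same PR.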