From 324bb9b0b0ae27393046c3e540a4b97d7e295f4e Mon Sep 17 00:00:00 2001
From: shewu
Date: Thu, 20 Nov 2025 17:04:34 +0800
Subject: [PATCH 1/2] Qualcomm AI Engine Direct - Support PARSeq in floating point precision

---
 backends/qualcomm/_passes/__init__.py         |  2 +
 backends/qualcomm/_passes/decompose_triu.py   | 62 +++++++++++++++++++
 .../_passes/lift_constant_scalar_operands.py  |  1 +
 backends/qualcomm/_passes/qnn_pass_manager.py |  3 +
 backends/qualcomm/builders/README.md          |  2 +-
 backends/qualcomm/builders/op_linear.py       |  4 +-
 backends/qualcomm/scripts/download_qnn_sdk.py |  4 +-
 7 files changed, 73 insertions(+), 5 deletions(-)
 create mode 100644 backends/qualcomm/_passes/decompose_triu.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 26b2bdc96c9..12370fbc049 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -24,6 +24,7 @@
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
 from .decompose_threshold import DecomposeThreshold
+from .decompose_triu import DecomposeTriu
 from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -67,6 +68,7 @@
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
+    DecomposeTriu,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
diff --git a/backends/qualcomm/_passes/decompose_triu.py b/backends/qualcomm/_passes/decompose_triu.py
new file mode 100644
index 00000000000..554c0c98a02
--- /dev/null
+++ b/backends/qualcomm/_passes/decompose_triu.py
@@ -0,0 +1,62 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._decomp import get_decompositions
+
+from .utils import merge_decomposed_graph
+
+
+class DecomposeTriu(ExportPass):
+    """
+    Decompose triu for quantization annotation to work properly.
+ """ + + def __init__(self) -> None: + super().__init__() + def _replace_output(self, node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict): + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[output_node.args[0]], + ) + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + decom_mappings = get_decompositions( + [torch.ops.aten.triu.default] + ) + + for node in graph.nodes: + if node.target == torch.ops.aten.triu.default: + decomposed_module = make_fx( + node.target, + decomposition_table=decom_mappings, + tracing_mode="fake", + )(node.args[0].meta["val"], node.args[1]) + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {} + remap["arg0_1"] = node.args[0] + + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + predicate=lambda decomp_node: "arg1_1" not in decomp_node.name, + output_processor=self._replace_output, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index 52bdf7fa090..e5d9371709d 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -55,6 +55,7 @@ class TensorOpInfo: aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True), aten.where.Scalar: TensorOpInfo(aten.where.self, False, True), aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False), + aten.masked_fill_.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False), aten.bitwise_xor.Scalar: TensorOpInfo(aten.bitwise_xor.Tensor, False, False), } diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 796662ca6b3..9f8e914bb41 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -29,6 +29,7 @@ DecomposeRoll, DecomposeSilu, DecomposeThreshold, + DecomposeTriu, DecomposeWrapWithAutocast, ExpandBroadcastTensorShape, FixedLinearKeepDim, @@ -203,6 +204,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeRoll()) self.add_pass(DecomposeSilu()) self.add_pass(DecomposeThreshold()) + self.add_pass(DecomposeTriu()) self.add_pass(DecomposeWrapWithAutocast()) self.add_pass(DecomposeEinsum()) self.add_pass(DecomposeExpM1()) @@ -221,6 +223,7 @@ def transform_for_export_pipeline( self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) self.add_pass(DecomposeThreshold()) + self.add_pass(DecomposeTriu()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeExpM1()) self.add_pass(DecomposeWrapWithAutocast()) diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 61ae1061214..5d629c73b51 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -37,7 +37,7 @@ class MyModel(torch.nn.Module): ``` At the time we try to lower it with Qualcomm backend: ```python -from excutorch.examples.qualcomm.utils import build_executorch_binary +from executorch.examples.qualcomm.utils import build_executorch_binary 
 
 build_executorch_binary(
     model=MyModel(),
diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py
index d5ac153b8d1..5d62901595e 100644
--- a/backends/qualcomm/builders/op_linear.py
+++ b/backends/qualcomm/builders/op_linear.py
@@ -56,12 +56,12 @@ def define_node(
                 [-1, 1]
             )
 
-        weight_tensor = get_parameter(weight_node, self.edge_program)
+        weight_tensor = self.get_tensor(weight_node, node)
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
             node,
             weight_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
         linear_input_tensors.append(weight_tensor_wrapper)
diff --git a/backends/qualcomm/scripts/download_qnn_sdk.py b/backends/qualcomm/scripts/download_qnn_sdk.py
index 747524a0e5b..adebecf0ebe 100644
--- a/backends/qualcomm/scripts/download_qnn_sdk.py
+++ b/backends/qualcomm/scripts/download_qnn_sdk.py
@@ -509,7 +509,7 @@ def _ensure_qnn_sdk_lib() -> bool:
     logger.info("[QNN] Loading %s", qnn_lib)
     lib_loaded = False
     try:
-        ctypes.CDLL(str(qnn_lib), mode=ctypes.RTLD_GLOBAL)
+        ctypes.CDLL(str(qnn_lib), mode=ctypes.RTLD_LOCAL)
         logger.info("[QNN] Loaded libQnnHtp.so from packaged SDK.")
         lib_loaded = True
     except OSError as e:
@@ -528,7 +528,7 @@ def _load_libcxx_libs(lib_path):
     logger.debug("sorted_candidates: %s", sorted_candidates)
     for sofile in sorted_candidates:
         try:
-            ctypes.CDLL(str(sofile), mode=ctypes.RTLD_GLOBAL)
+            ctypes.CDLL(str(sofile), mode=ctypes.RTLD_LOCAL)
             logger.info("Loaded %s", sofile.name)
         except OSError as e:
             logger.warning("[WARN] Failed to load %s: %s", sofile.name, e)

From e5d085ed71ee00a7c10e9595808cad3ba7b532b0 Mon Sep 17 00:00:00 2001
From: shewu
Date: Fri, 21 Nov 2025 17:54:35 +0800
Subject: [PATCH 2/2] linting

---
 backends/qualcomm/_passes/decompose_triu.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/backends/qualcomm/_passes/decompose_triu.py b/backends/qualcomm/_passes/decompose_triu.py
index 554c0c98a02..c97d38b7738 100644
--- a/backends/qualcomm/_passes/decompose_triu.py
+++ b/backends/qualcomm/_passes/decompose_triu.py
@@ -5,10 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Dict
+
 import torch
 from executorch.exir.pass_base import ExportPass, PassResult
-from torch.fx.experimental.proxy_tensor import make_fx
 from torch._decomp import get_decompositions
+from torch.fx.experimental.proxy_tensor import make_fx
 
 from .utils import merge_decomposed_graph
 
@@ -20,18 +21,20 @@ class DecomposeTriu(ExportPass):
 
     def __init__(self) -> None:
         super().__init__()
+
-    def _replace_output(self, node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict):
+    def _replace_output(
+        self, node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict
+    ):
         for user in node.users.copy():
             # remap
             user.replace_input_with(
                 node,
                 remap[output_node.args[0]],
             )
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph = graph_module.graph
-        decom_mappings = get_decompositions(
-            [torch.ops.aten.triu.default]
-        )
+        decom_mappings = get_decompositions([torch.ops.aten.triu.default])
 
         for node in graph.nodes:
             if node.target == torch.ops.aten.triu.default:
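
Note: a minimal standalone sketch of the mechanism DecomposeTriu relies on. It assumes only that torch._decomp ships a registered decomposition for aten.triu.default (the same assumption the pass makes); everything else is standard PyTorch API.

```python
# Sketch (not part of the patch): tracing triu through make_fx with the
# same decomposition table the pass builds. After tracing, no
# aten.triu.default node remains, which is what lets the quantizer
# annotate the graph with ops it actually understands.
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

decomp_table = get_decompositions([torch.ops.aten.triu.default])

def fn(x):
    # aten.triu.default is the op DecomposeTriu rewrites
    return torch.triu(x, 1)

gm = make_fx(fn, decomposition_table=decomp_table)(torch.randn(4, 4))
assert all(n.target != torch.ops.aten.triu.default for n in gm.graph.nodes)
print(gm.graph)  # only primitive masking ops (arange/le/where-style) remain
```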
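The new aten.masked_fill_.Scalar entry in lift_constant_scalar_operands.py serves the same model family: decoders of this kind commonly build a causal attention mask with triu followed by an in-place masked_fill_. The snippet below is illustrative only; the causal_mask helper and its shapes are hypothetical, not taken from the PARSeq sources.

```python
# Hypothetical sketch of a mask construction that exercises both ops.
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Boolean mask that is True strictly above the diagonal
    keep_out = torch.ones(seq_len, seq_len, dtype=torch.bool).triu(1)
    scores = torch.zeros(seq_len, seq_len)
    # In-place fill with a Python float traces to aten.masked_fill_.Scalar,
    # which the pass now maps to the tensor variant like its out-of-place
    # counterpart.
    scores.masked_fill_(keep_out, float("-inf"))
    return scores

print(causal_mask(4))
```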