From 5156b446873b9d13a2a0a3f411645ab2af19fadd Mon Sep 17 00:00:00 2001
From: haowhsu-quic
Date: Wed, 18 Dec 2024 13:19:21 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - op enablement

summary
- abs, arange, eq, full_like, ge, gt, log, le, lt, min, max, repeat
- test cases & bug fixes
- support Snapdragon 8 Elite
---
 backends/qualcomm/README.md                   |   3 +-
 .../qualcomm/_passes/convert_to_linear.py     |   6 +-
 backends/qualcomm/_passes/layout_transform.py |  17 +-
 .../_passes/recompose_pixel_unshuffle.py      |  10 +-
 .../qualcomm/_passes/remove_redundancy.py     |  46 +--
 backends/qualcomm/builders/__init__.py        |  24 ++
 backends/qualcomm/builders/op_abs.py          |  56 ++++
 backends/qualcomm/builders/op_arange.py       |  37 +++
 backends/qualcomm/builders/op_conv2d.py       |   2 +-
 backends/qualcomm/builders/op_eq.py           |  98 +++++++
 backends/qualcomm/builders/op_full_like.py    |  37 +++
 backends/qualcomm/builders/op_ge.py           |  98 +++++++
 backends/qualcomm/builders/op_gt.py           |  98 +++++++
 backends/qualcomm/builders/op_le.py           |  98 +++++++
 backends/qualcomm/builders/op_linear.py       |   9 +-
 backends/qualcomm/builders/op_log.py          |  57 ++++
 backends/qualcomm/builders/op_lt.py           |  98 +++++++
 backends/qualcomm/builders/op_max.py          |  61 ++++
 backends/qualcomm/builders/op_min.py          |  61 ++++
 backends/qualcomm/builders/op_repeat.py       |  67 +++++
 .../qualcomm/builders/op_split_with_sizes.py  |  10 +-
 backends/qualcomm/builders/op_to.py           |   2 +-
 backends/qualcomm/builders/qnn_constants.py   |  45 +++
 backends/qualcomm/partition/common_defs.py    |   8 +-
 .../qualcomm/partition/qnn_partitioner.py     |   3 +
 backends/qualcomm/quantizer/annotators.py     |  81 +++++-
 .../serialization/qc_compiler_spec.fbs        |   2 +
 backends/qualcomm/serialization/qc_schema.py  |   3 +
 backends/qualcomm/tests/models.py             | 148 +++++++++-
 backends/qualcomm/tests/test_qnn_delegate.py  | 265 +++++++++++++++++-
 backends/qualcomm/utils/utils.py              |   3 +
 examples/qualcomm/utils.py                    |   2 +
 32 files changed, 1500 insertions(+), 55 deletions(-)
 create mode 100644 backends/qualcomm/builders/op_abs.py
 create mode 100644 backends/qualcomm/builders/op_arange.py
 create mode 100644 backends/qualcomm/builders/op_eq.py
 create mode 100644 backends/qualcomm/builders/op_full_like.py
 create mode 100644 backends/qualcomm/builders/op_ge.py
 create mode 100644 backends/qualcomm/builders/op_gt.py
 create mode 100644 backends/qualcomm/builders/op_le.py
 create mode 100644 backends/qualcomm/builders/op_log.py
 create mode 100644 backends/qualcomm/builders/op_lt.py
 create mode 100644 backends/qualcomm/builders/op_max.py
 create mode 100644 backends/qualcomm/builders/op_min.py
 create mode 100644 backends/qualcomm/builders/op_repeat.py

diff --git a/backends/qualcomm/README.md b/backends/qualcomm/README.md
index a0cb5a5a502..9e1974bad6a 100644
--- a/backends/qualcomm/README.md
+++ b/backends/qualcomm/README.md
@@ -20,6 +20,7 @@ Please check `generate_qnn_executorch_compiler_spec()` in
 - Snapdragon 8 Gen 1+
 - Snapdragon 8 Gen 2
 - Snapdragon 8 Gen 3
+- Snapdragon 8 Elite
 
 ### Adding more supported Chipset
 Currently, users cannot add additional chipset models because the chipset ID is not accessible to community users. If you have specific chipset models you wish to add, please contact one of the authors in the `Code Reviews` section at the bottom of this page.
@@ -120,11 +121,9 @@ PRs are always welcome to help improve the codebase in a comprehensive manner. B
 
 - **Code Reviews**:
Please ping authors in Qualcomm AI Engine Direct related PRs for reviewing, possible candidates are listed below: - - [chiwwang](https://github.com/chiwwang) - [shewu-quic](https://github.com/shewu-quic) - [chunit-quic](https://github.com/chunit-quic) - [winskuo-quic](https://github.com/winskuo-quic) - - [chuntl](https://github.com/chuntl) - [haowhsu-quic](https://github.com/haowhsu-quic) Thanks again for your contribution! diff --git a/backends/qualcomm/_passes/convert_to_linear.py b/backends/qualcomm/_passes/convert_to_linear.py index e7c4e8f9a92..87b9f8a74b8 100644 --- a/backends/qualcomm/_passes/convert_to_linear.py +++ b/backends/qualcomm/_passes/convert_to_linear.py @@ -110,11 +110,11 @@ def _convert_to_linear( # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node # TODO: Find a more general conditional statement. linear_output = linear_node.meta["val"] - if linear_output.dim() == 3 and linear_output.shape[0] == 1: + if linear_output.dim() >= 3: with gm.graph.inserting_after(input_node): input_users = list(input_node.users.keys()) input_tensor = input_node.meta["val"] - squeeze_dim = input_tensor.shape[-2:] + squeeze_dim = (-1, input_tensor.shape[-1]) squeeze_node = gm.graph.create_node( "call_function", self.view_copy, @@ -149,7 +149,7 @@ def _convert_to_linear( unsqueeze_node.meta[k] = v # update linear node's shape linear_node.meta["val"] = linear_output.reshape( - linear_output.shape[-2:] + (squeeze_node.meta["val"].shape[0], linear_output.shape[-1]) ) for user in output_users: user.replace_input_with(linear_node, unsqueeze_node) diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 66bada86dc1..098910ed86f 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -42,6 +42,7 @@ class LayoutTransform(ExportPass): } layout_agnostic_ops = { + exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.cat.default, @@ -49,27 +50,41 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.clamp.default, exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.eq.Scalar, + exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.ge.Scalar, + exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.gt.Scalar, + exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.hardswish.default, exir_ops.edge.aten.hardsigmoid.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.leaky_relu.default, + exir_ops.edge.aten.le.Scalar, + exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.log.default, + exir_ops.edge.aten.lt.Scalar, + exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten._log_softmax.default, + exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.prelu.default, + exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.relu.default, exir_ops.edge.aten._softmax.default, # TODO: Need to find a new solution to do "axis_order" to transform axis. 
exir_ops.edge.aten.sigmoid.default,
+        exir_ops.edge.aten.split_with_sizes.default,
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
         exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
-        exir_ops.edge.aten.split_with_sizes.default,
         *q_ops,
         *dq_ops,
         _operator.getitem,
diff --git a/backends/qualcomm/_passes/recompose_pixel_unshuffle.py b/backends/qualcomm/_passes/recompose_pixel_unshuffle.py
index 00d46639089..7aac4fb823e 100644
--- a/backends/qualcomm/_passes/recompose_pixel_unshuffle.py
+++ b/backends/qualcomm/_passes/recompose_pixel_unshuffle.py
@@ -21,9 +21,8 @@ def __init__(self, quantization_capture=False):
         self.view_target = exir_ops.edge.aten.view_copy.default
         self.op = exir_ops.edge.aten.pixel_unshuffle.default
 
-        self.quantization_capture = quantization_capture
         if quantization_capture:
-            self.reshape_target = torch.ops.aten._unsafe_view.default
+            self.reshape_target = torch.ops.aten.reshape.default
             self.permute_target = torch.ops.aten.permute.default
             self.view_target = torch.ops.aten.view.default
             self.op = torch.ops.aten.pixel_unshuffle.default
@@ -35,12 +34,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             if node.op == "call_function" and node.target == self.reshape_target:
                 with graph.inserting_after(node):
-                    # Clone op still exists between permute and reshape_target during quantization,
-                    # so we need to check for args[0].args[0] to get permute node
-                    if self.quantization_capture:
-                        premute_node = node.args[0].args[0]
-                    else:
-                        premute_node = node.args[0]
+                    premute_node = node.args[0]
                     if any(
                         [
                             len(node.args[1]) != 4,
diff --git a/backends/qualcomm/_passes/remove_redundancy.py b/backends/qualcomm/_passes/remove_redundancy.py
index 2b14aed6c7f..07b13d4dd67 100644
--- a/backends/qualcomm/_passes/remove_redundancy.py
+++ b/backends/qualcomm/_passes/remove_redundancy.py
@@ -14,31 +14,37 @@ class RemoveRedundancy(ExportPass):
     Trim certain operators to reduce unnecessary overhead.
     """
 
-    redundant_ops = {
-        torch.clone,
-        torch.ops.aten.clone.default,
-        exir_ops.edge.aten.clone.default,
-        torch.ops.aten.alias.default,
-        exir_ops.edge.aten.alias.default,
-        exir_ops.edge.aten.lift_fresh_copy.default,
-        # remove this target if '_skip_dim_order' is set to False
-        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
-        # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
-        exir_ops.edge.aten._to_copy.default,
-    }
-
     def __init__(self):
         super(RemoveRedundancy, self).__init__()
+        self.redundant_ops = {
+            torch.clone: self._default_condition,
+            torch.ops.aten.clone.default: self._default_condition,
+            exir_ops.edge.aten.clone.default: self._default_condition,
+            torch.ops.aten.alias.default: self._default_condition,
+            exir_ops.edge.aten.alias.default: self._default_condition,
+            exir_ops.edge.aten.lift_fresh_copy.default: self._default_condition,
+            # remove this target if '_skip_dim_order' is set to False
+            exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
+            # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
+            exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
+        }
+
+    def _dim_order_op_condition(self, node):
+        dim_order = node.kwargs.get("dim_order")
+        # skip if a layout hint is present,
+        # e.g. 
(0, 2, 3, 1) != (0, 1, 2, 3)
+        return dim_order != list(range(len(dim_order)))
+
+    def _to_copy_op_condition(self, node):
+        return "memory_format" in node.kwargs
+
+    def _default_condition(self, node):
+        return True
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         for n in graph_module.graph.nodes:
-            if n.target not in self.redundant_ops:
-                continue
-
-            # do not remove cast operator
-            if (
-                n.target == exir_ops.edge.aten._to_copy.default
-                and "memory_format" not in n.kwargs
+            if n.target not in self.redundant_ops or not self.redundant_ops[n.target](
+                n
             ):
                 continue
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index 1b853c60c5d..61ed30679e1 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -6,7 +6,9 @@
 from . import (
     node_visitor,
+    op_abs,
     op_add,
+    op_arange,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
@@ -19,26 +21,36 @@
     op_dequantize,
     op_div,
     op_embedding,
+    op_eq,
     op_expand,
+    op_full_like,
+    op_ge,
     op_gelu,
     op_group_norm,
+    op_gt,
     op_hardsigmoid,
     op_hardswish,
     op_hardtanh,
     op_index,
     op_index_put,
     op_layer_norm,
+    op_le,
     op_linear,
+    op_log,
     op_log_softmax,
+    op_lt,
     op_matmul,
+    op_max,
     op_max_pool2d,
     op_mean_dim,
+    op_min,
     op_mul,
     op_pad,
     op_pow,
     op_prelu,
     op_quantize,
     op_relu,
+    op_repeat,
     op_reshape,
     op_rms_norm,
     op_rsqrt,
@@ -65,7 +77,9 @@
 __all__ = [
     node_visitor,
+    op_abs,
     op_add,
+    op_arange,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
@@ -78,26 +92,36 @@
     op_dequantize,
     op_div,
     op_embedding,
+    op_eq,
     op_expand,
+    op_full_like,
+    op_ge,
     op_gelu,
     op_group_norm,
+    op_gt,
     op_hardswish,
     op_hardtanh,
     op_hardsigmoid,
     op_index,
     op_index_put,
     op_layer_norm,
+    op_le,
     op_linear,
+    op_log,
     op_log_softmax,
+    op_lt,
     op_matmul,
+    op_max,
     op_max_pool2d,
     op_mean_dim,
+    op_min,
     op_mul,
     op_pad,
     op_pow,
     op_prelu,
     op_quantize,
     op_relu,
+    op_repeat,
     op_reshape,
     op_rms_norm,
     op_rsqrt,
diff --git a/backends/qualcomm/builders/op_abs.py b/backends/qualcomm/builders/op_abs.py
new file mode 100644
index 00000000000..002ffe85208
--- /dev/null
+++ b/backends/qualcomm/builders/op_abs.py
@@ -0,0 +1,56 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
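Stepping back to the `RemoveRedundancy` change earlier in this patch: the flat `redundant_ops` set becomes a per-target predicate table, so each removable op decides removal per node. A minimal, self-contained sketch of the same dispatch shape, using plain strings in place of the real exir/aten targets:

```python
# Sketch of RemoveRedundancy's predicate-table dispatch; string keys
# stand in for the real op targets.
class FakeNode:
    def __init__(self, target, kwargs=None):
        self.target = target
        self.kwargs = kwargs or {}

redundant_ops = {
    "aten.clone": lambda n: True,  # unconditionally removable
    # _to_copy is only removable when it is a pure memory-format change;
    # a dtype cast (no "memory_format" kwarg) must be kept.
    "aten._to_copy": lambda n: "memory_format" in n.kwargs,
}

def is_removable(n: FakeNode) -> bool:
    cond = redundant_ops.get(n.target)
    return cond is not None and cond(n)

assert is_removable(FakeNode("aten.clone"))
assert not is_removable(FakeNode("aten._to_copy"))  # cast: kept
assert is_removable(FakeNode("aten._to_copy", {"memory_format": 0}))
```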
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpElementWiseAbs, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Abs(NodeVisitor): + target = ["aten.abs.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + abs_output_tensors = [output_tensor_wrapper] + + input_node = node.args[0] + input_tensor_wrapper = self.define_tensor( + input_node, + node, + self.get_tensor(input_node, node), + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + abs_input_tensors = [input_tensor_wrapper] + + abs_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseAbs.op_name, + ) + abs_op.AddInputTensors(abs_input_tensors) + abs_op.AddOutputTensors(abs_output_tensors) + + return abs_op diff --git a/backends/qualcomm/builders/op_arange.py b/backends/qualcomm/builders/op_arange.py new file mode 100644 index 00000000000..b719b401191 --- /dev/null +++ b/backends/qualcomm/builders/op_arange.py @@ -0,0 +1,37 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor + + +@register_node_visitor +class Arange(NodeVisitor): + target = ["aten.arange.start_step"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + start, end = node.args[0:2] + step = node.args[2] if len(node.args) > 2 else 1 + out_tensor = torch.arange(start, end, step) + + self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index 9daeab6d4bf..a6051636d3e 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -238,7 +238,7 @@ def _define_conv1d( padding_shape, dilation, dilation_shape, - groups, + groups=groups, ) op_wrapper_list.append(conv_op) diff --git a/backends/qualcomm/builders/op_eq.py b/backends/qualcomm/builders/op_eq.py new file mode 100644 index 00000000000..7717180d976 --- /dev/null +++ b/backends/qualcomm/builders/op_eq.py @@ -0,0 +1,98 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
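A note on `op_arange` above: every argument of `aten.arange.start_step` is a compile-time constant, so the builder emits no QNN op at all; it bakes the values into a `QNN_TENSOR_TYPE_STATIC` tensor (the partitioner change later in this patch admits such constant ops). A small runnable sketch of that folding:

```python
import torch

# Mirror of op_arange's constant folding: node.args fully determine the
# output, so the values can be materialized ahead of compilation.
def fold_arange(args):
    start, end = args[0:2]
    step = args[2] if len(args) > 2 else 1  # step defaults to 1
    return torch.arange(start, end, step)

assert torch.equal(fold_arange((1, 11)), torch.arange(1, 11))
assert torch.equal(fold_arange((1, 6, 0.5)), torch.arange(1, 6, 0.5))
```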
+
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseEqual, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Equal(NodeVisitor):
+    target = ["aten.eq.Tensor", "aten.eq.Scalar"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        output_tensors = [output_tensor_wrapper]
+
+        input_tensors = []
+        for index in range(2):
+            input_node = node.args[index]
+            if isinstance(input_node, torch.fx.Node):
+                input_tensor = self.get_tensor(input_node, node)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
+            else:
+                scalar = input_node
+                input_tensor = torch.tensor(scalar, dtype=torch.float32)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
+
+                # 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
+                input_node = torch.fx.Node(
+                    node.graph,
+                    node.name + "_runtime_scalar",
+                    "call_function",
+                    exir_ops.edge.aten.scalar_tensor.default,
+                    (),  # args
+                    {},  # kwargs
+                )
+                # Because the output data type of the eq node is boolean.
+                # We need to take the quant attr from the non-scalar node.
+                if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS):
+                    quant_attrs = quant_attrs.copy()
+                    quant_range = (
+                        quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
+                    )
+                    quant_attrs[QCOM_ZERO_POINT] = (
+                        0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX]
+                    )
+                    quant_attrs[QCOM_SCALE] = (
+                        scalar / quant_range if scalar >= 0 else -scalar / quant_range
+                    )
+                    input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+
+            input_tensor_wrapper = self.define_tensor(
+                input_node,
+                node,
+                input_tensor,
+                tensor_type,
+                nodes_to_wrappers,
+            )
+            input_tensors.append(input_tensor_wrapper)
+
+        eq_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseEqual.op_name,
+        )
+        eq_op.AddInputTensors(input_tensors)
+        eq_op.AddOutputTensors(output_tensors)
+
+        return eq_op
diff --git a/backends/qualcomm/builders/op_full_like.py b/backends/qualcomm/builders/op_full_like.py
new file mode 100644
index 00000000000..ab25d49446d
--- /dev/null
+++ b/backends/qualcomm/builders/op_full_like.py
@@ -0,0 +1,37 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor + + +@register_node_visitor +class FullLike(NodeVisitor): + target = ["aten.full_like.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + in_tensor = node.args[0].meta["val"] + ref_tensor = torch.zeros(in_tensor.shape, dtype=in_tensor.dtype) + out_tensor = torch.full_like(ref_tensor, node.args[1]) + + self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) diff --git a/backends/qualcomm/builders/op_ge.py b/backends/qualcomm/builders/op_ge.py new file mode 100644 index 00000000000..552cab659cc --- /dev/null +++ b/backends/qualcomm/builders/op_ge.py @@ -0,0 +1,98 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_QUANT_MAX, + QCOM_QUANT_MIN, + QCOM_SCALE, + QCOM_ZERO_POINT, +) +from executorch.exir.dialects._ops import ops as exir_ops + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpElementWiseGreaterEqual, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class GreaterEqual(NodeVisitor): + target = ["aten.ge.Tensor", "aten.ge.Scalar"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + output_tensors = [output_tensor_wrapper] + + input_tensors = [] + for index in range(2): + input_node = node.args[index] + if isinstance(input_node, torch.fx.Node): + input_tensor = self.get_tensor(input_node, node) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + else: + scalar = input_node + input_tensor = torch.tensor(scalar, dtype=torch.float32) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC + + # 'graph', 'name', 'op', 'target', 'args', and 'kwargs' + input_node = torch.fx.Node( + node.graph, + node.name + "_runtime_scalar", + "call_function", + exir_ops.edge.aten.scalar_tensor.default, + (), # args + {}, # kwargs + ) + # Because the output data type of the ge node is boolean. + # We need to take the quant attr from the non-scalar node. 
+                if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS):
+                    quant_attrs = quant_attrs.copy()
+                    quant_range = (
+                        quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
+                    )
+                    quant_attrs[QCOM_ZERO_POINT] = (
+                        0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX]
+                    )
+                    quant_attrs[QCOM_SCALE] = (
+                        scalar / quant_range if scalar >= 0 else -scalar / quant_range
+                    )
+                    input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+
+            input_tensor_wrapper = self.define_tensor(
+                input_node,
+                node,
+                input_tensor,
+                tensor_type,
+                nodes_to_wrappers,
+            )
+            input_tensors.append(input_tensor_wrapper)
+
+        ge_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseGreaterEqual.op_name,
+        )
+        ge_op.AddInputTensors(input_tensors)
+        ge_op.AddOutputTensors(output_tensors)
+
+        return ge_op
diff --git a/backends/qualcomm/builders/op_gt.py b/backends/qualcomm/builders/op_gt.py
new file mode 100644
index 00000000000..62bd79739ea
--- /dev/null
+++ b/backends/qualcomm/builders/op_gt.py
@@ -0,0 +1,98 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseGreater, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class GreaterThan(NodeVisitor):
+    target = ["aten.gt.Tensor", "aten.gt.Scalar"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        output_tensors = [output_tensor_wrapper]
+
+        input_tensors = []
+        for index in range(2):
+            input_node = node.args[index]
+            if isinstance(input_node, torch.fx.Node):
+                input_tensor = self.get_tensor(input_node, node)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
+            else:
+                scalar = input_node
+                input_tensor = torch.tensor(scalar, dtype=torch.float32)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
+
+                # 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
+                input_node = torch.fx.Node(
+                    node.graph,
+                    node.name + "_runtime_scalar",
+                    "call_function",
+                    exir_ops.edge.aten.scalar_tensor.default,
+                    (),  # args
+                    {},  # kwargs
+                )
+                # Because the output data type of the gt node is boolean.
+                # We need to take the quant attr from the non-scalar node.
+                if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS):
+                    quant_attrs = quant_attrs.copy()
+                    quant_range = (
+                        quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
+                    )
+                    quant_attrs[QCOM_ZERO_POINT] = (
+                        0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX]
+                    )
+                    quant_attrs[QCOM_SCALE] = (
+                        scalar / quant_range if scalar >= 0 else -scalar / quant_range
+                    )
+                    input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+
+            input_tensor_wrapper = self.define_tensor(
+                input_node,
+                node,
+                input_tensor,
+                tensor_type,
+                nodes_to_wrappers,
+            )
+            input_tensors.append(input_tensor_wrapper)
+
+        gt_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseGreater.op_name,
+        )
+        gt_op.AddInputTensors(input_tensors)
+        gt_op.AddOutputTensors(output_tensors)
+
+        return gt_op
diff --git a/backends/qualcomm/builders/op_le.py b/backends/qualcomm/builders/op_le.py
new file mode 100644
index 00000000000..6c93fc373e6
--- /dev/null
+++ b/backends/qualcomm/builders/op_le.py
@@ -0,0 +1,98 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseLessEqual, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class LessEqual(NodeVisitor):
+    target = ["aten.le.Tensor", "aten.le.Scalar"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        output_tensors = [output_tensor_wrapper]
+
+        input_tensors = []
+        for index in range(2):
+            input_node = node.args[index]
+            if isinstance(input_node, torch.fx.Node):
+                input_tensor = self.get_tensor(input_node, node)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
+            else:
+                scalar = input_node
+                input_tensor = torch.tensor(scalar, dtype=torch.float32)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
+
+                # 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
+                input_node = torch.fx.Node(
+                    node.graph,
+                    node.name + "_runtime_scalar",
+                    "call_function",
+                    exir_ops.edge.aten.scalar_tensor.default,
+                    (),  # args
+                    {},  # kwargs
+                )
+                # Because the output data type of the le node is boolean.
+                # We need to take the quant attr from the non-scalar node.
+ if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + quant_range = ( + quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN] + ) + quant_attrs[QCOM_ZERO_POINT] = ( + 0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX] + ) + quant_attrs[QCOM_SCALE] = ( + scalar / quant_range if scalar >= 0 else -scalar / quant_range + ) + input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + tensor_type, + nodes_to_wrappers, + ) + input_tensors.append(input_tensor_wrapper) + + le_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseLessEqual.op_name, + ) + le_op.AddInputTensors(input_tensors) + le_op.AddOutputTensors(output_tensors) + + return le_op diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index a16bbd28c98..71b6072b9e5 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -75,12 +75,19 @@ def define_node( f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. per channel bias quantization is not support yet.", stacklevel=1, ) + + bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC bias_tensor = get_parameter(bias_node, self.edge_program) + # if bias_node is getitem + if bias_tensor is None: + bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + bias_tensor = bias_node.meta["val"] + bias_tensor_wrapper = self.define_tensor( bias_node, node, bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + bias_tensor_type, nodes_to_wrappers, ) linear_input_tensors.append(bias_tensor_wrapper) diff --git a/backends/qualcomm/builders/op_log.py b/backends/qualcomm/builders/op_log.py new file mode 100644 index 00000000000..bcc40aa6268 --- /dev/null +++ b/backends/qualcomm/builders/op_log.py @@ -0,0 +1,57 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
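The scalar branch shared by the comparison builders (`op_eq`, `op_ge`, `op_gt`, `op_le`, and `op_lt` below) re-derives quantization parameters for the constant operand from the non-scalar peer, as the comments in those files describe. Here is that arithmetic in isolation, runnable on its own, with plain dict keys standing in for the `QCOM_*` constants:

```python
def scalar_quant_attrs(scalar: float, peer: dict) -> dict:
    # Mirrors the builders above: reuse the peer tensor's quant range,
    # then pick a scale/zero_point pair for the constant scalar.
    attrs = peer.copy()
    quant_range = peer["quant_max"] - peer["quant_min"]
    attrs["zero_point"] = 0 if scalar >= 0 else peer["quant_max"]
    attrs["scale"] = (
        scalar / quant_range if scalar >= 0 else -scalar / quant_range
    )
    return attrs

# e.g. an 8-bit peer with range [-128, 127] and the scalar 0.5:
attrs = scalar_quant_attrs(0.5, {"quant_min": -128, "quant_max": 127})
assert attrs["zero_point"] == 0 and abs(attrs["scale"] - 0.5 / 255) < 1e-9
```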
+
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseLog, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Log(NodeVisitor):
+    target = ["aten.log.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        log_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        log_input_tensors = [log_inp_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        log_output_tensors = [output_tensor_wrapper]
+
+        log_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseLog.op_name,
+        )
+        log_op.AddInputTensors(log_input_tensors)
+        log_op.AddOutputTensors(log_output_tensors)
+
+        return log_op
diff --git a/backends/qualcomm/builders/op_lt.py b/backends/qualcomm/builders/op_lt.py
new file mode 100644
index 00000000000..9f7a290a6ad
--- /dev/null
+++ b/backends/qualcomm/builders/op_lt.py
@@ -0,0 +1,98 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseLess, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class LessThan(NodeVisitor):
+    target = ["aten.lt.Tensor", "aten.lt.Scalar"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        output_tensors = [output_tensor_wrapper]
+
+        input_tensors = []
+        for index in range(2):
+            input_node = node.args[index]
+            if isinstance(input_node, torch.fx.Node):
+                input_tensor = self.get_tensor(input_node, node)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
+            else:
+                scalar = input_node
+                input_tensor = torch.tensor(scalar, dtype=torch.float32)
+                tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
+
+                # 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
+                input_node = torch.fx.Node(
+                    node.graph,
+                    node.name + "_runtime_scalar",
+                    "call_function",
+                    exir_ops.edge.aten.scalar_tensor.default,
+                    (),  # args
+                    {},  # kwargs
+                )
+                # Because the output data type of the lt node is boolean.
+                # We need to take the quant attr from the non-scalar node.
+                if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS):
+                    quant_attrs = quant_attrs.copy()
+                    quant_range = (
+                        quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
+                    )
+                    quant_attrs[QCOM_ZERO_POINT] = (
+                        0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX]
+                    )
+                    quant_attrs[QCOM_SCALE] = (
+                        scalar / quant_range if scalar >= 0 else -scalar / quant_range
+                    )
+                    input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+
+            input_tensor_wrapper = self.define_tensor(
+                input_node,
+                node,
+                input_tensor,
+                tensor_type,
+                nodes_to_wrappers,
+            )
+            input_tensors.append(input_tensor_wrapper)
+
+        lt_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseLess.op_name,
+        )
+        lt_op.AddInputTensors(input_tensors)
+        lt_op.AddOutputTensors(output_tensors)
+
+        return lt_op
diff --git a/backends/qualcomm/builders/op_max.py b/backends/qualcomm/builders/op_max.py
new file mode 100644
index 00000000000..7d41358a266
--- /dev/null
+++ b/backends/qualcomm/builders/op_max.py
@@ -0,0 +1,61 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseMaximum, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Max(NodeVisitor):
+    target = ["aten.maximum.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        max_output_tensors = [output_tensor_wrapper]
+
+        max_input_tensors = []
+        for index in range(2):
+            input_node = node.args[index]
+            input_tensor = self.get_tensor(input_node, node)
+            tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
+
+            input_tensor_wrapper = self.define_tensor(
+                input_node,
+                node,
+                input_tensor,
+                tensor_type,
+                nodes_to_wrappers,
+            )
+            max_input_tensors.append(input_tensor_wrapper)
+
+        max_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseMaximum.op_name,
+        )
+        max_op.AddInputTensors(max_input_tensors)
+        max_op.AddOutputTensors(max_output_tensors)
+
+        return max_op
diff --git a/backends/qualcomm/builders/op_min.py b/backends/qualcomm/builders/op_min.py
new file mode 100644
index 00000000000..0df2796974d
--- /dev/null
+++ b/backends/qualcomm/builders/op_min.py
@@ -0,0 +1,61 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpElementWiseMinimum, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Min(NodeVisitor): + target = ["aten.minimum.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + min_output_tensors = [output_tensor_wrapper] + + min_input_tensors = [] + for index in range(2): + input_node = node.args[index] + input_tensor = self.get_tensor(input_node, node) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + tensor_type, + nodes_to_wrappers, + ) + min_input_tensors.append(input_tensor_wrapper) + + min_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseMinimum.op_name, + ) + min_op.AddInputTensors(min_input_tensors) + min_op.AddOutputTensors(min_output_tensors) + + return min_op diff --git a/backends/qualcomm/builders/op_repeat.py b/backends/qualcomm/builders/op_repeat.py new file mode 100644 index 00000000000..9748f1e9619 --- /dev/null +++ b/backends/qualcomm/builders/op_repeat.py @@ -0,0 +1,67 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
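`op_repeat`, whose body follows, maps `aten.repeat.default` onto QNN's `Tile` op: the repeat counts become Tile's `multiples` tensor parameter, one entry per dimension. The shape relationship it relies on, as a quick runnable check:

```python
import torch

# output.shape[i] == input.shape[i] * multiples[i], which is exactly the
# contract of QNN Tile's "multiples" parameter.
x = torch.randn(2, 3)
multiples = (2, 4)
assert x.repeat(*multiples).shape == torch.Size([2 * 2, 3 * 4])
```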
+ +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpTile, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Repeat(NodeVisitor): + target = ["aten.repeat.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + multiples = cast(List[int], node.args[1]) + multiples_shape = [len(multiples)] + + tile_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + tile_op.AddInputTensors([input_tensor_wrapper]) + tile_op.AddOutputTensors([output_tensor_wrapper]) + tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(multiples_shape), + multiples_shape, + np.array(multiples, dtype=np.uint32), + True, + ) + return tile_op diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index 8e75fd3c10d..629110b3084 100644 --- a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -64,9 +64,13 @@ def define_node( split_indices.append(sum) split_indices_shape = [len(split_indices)] - dim = cast(int, node.args[2]) - if dim < 0: - dim = dim % len(input_tensor.shape) + + if len(node.args) > 2: + dim = cast(int, node.args[2]) + if dim < 0: + dim = dim % len(input_tensor.shape) + else: + dim = 0 if QCOM_AXIS_ORDER in node.meta: dim = node.meta[QCOM_AXIS_ORDER].index(dim) diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py index f5cfd4ecf6e..5fb016aef95 100644 --- a/backends/qualcomm/builders/op_to.py +++ b/backends/qualcomm/builders/op_to.py @@ -16,7 +16,7 @@ @register_node_visitor class To(NodeVisitor): - target = ["aten._to_copy.default"] + target = ["aten._to_copy.default", "dim_order_ops._to_dim_order_copy.default"] sufixed_8_offset_diff = 128 sufixed_16_offset_diff = 32768 epsilon = 1e-6 diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index cb48cf38ba5..f6b33b20a46 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -75,6 +75,11 @@ class OpDequantize: op_name: str = "Dequantize" +@dataclass(init=False, frozen=True) +class OpElementWiseAbs: + op_name: str = "ElementWiseAbs" + + @dataclass(init=False, frozen=True) class OpElementWiseAdd: op_name: str = "ElementWiseAdd" @@ -95,6 +100,46 @@ class OpElementWiseDivide: op_name: str = "ElementWiseDivide" +@dataclass(init=False, frozen=True) +class OpElementWiseEqual: + op_name: str = "ElementWiseEqual" + + +@dataclass(init=False, frozen=True) +class OpElementWiseGreater: + op_name: str = "ElementWiseGreater" + + +@dataclass(init=False, frozen=True) +class OpElementWiseGreaterEqual: + 
op_name: str = "ElementWiseGreaterEqual" + + +@dataclass(init=False, frozen=True) +class OpElementWiseLess: + op_name: str = "ElementWiseLess" + + +@dataclass(init=False, frozen=True) +class OpElementWiseLessEqual: + op_name: str = "ElementWiseLessEqual" + + +@dataclass(init=False, frozen=True) +class OpElementWiseLog: + op_name: str = "ElementWiseLog" + + +@dataclass(init=False, frozen=True) +class OpElementWiseMaximum: + op_name: str = "ElementWiseMaximum" + + +@dataclass(init=False, frozen=True) +class OpElementWiseMinimum: + op_name: str = "ElementWiseMinimum" + + @dataclass(init=False, frozen=True) class OpElementWiseMultiply: op_name: str = "ElementWiseMultiply" diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 1c24d00390d..7f49cfb7867 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -9,7 +9,6 @@ not_supported_operator = [ - exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, @@ -19,12 +18,15 @@ to_be_implemented_operator = [ exir_ops.edge.aten.any.dim, - exir_ops.edge.aten.eq.Scalar, - exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.logical_not.default, exir_ops.edge.aten.where.self, ] +constant_operator = [ + exir_ops.edge.aten.arange.start_step, + exir_ops.edge.aten.full_like.default, +] + allow_list_operator = [ _operator.getitem, ] diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 05054cc5d8c..93b1d50f5fe 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -29,6 +29,7 @@ from .common_defs import ( allow_list_operator, + constant_operator, not_supported_operator, to_be_implemented_operator, ) @@ -82,6 +83,8 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: op_wrapper = self.node_visitors[node.target.__name__].define_node( node, self.nodes_to_wrappers ) + if node.target in constant_operator: + return True op_wrapper_list = [] if isinstance(op_wrapper, List): diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 68d512a4e09..e1792cb1830 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -172,6 +172,31 @@ def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.eq.Scalar, torch.ops.aten.eq.Tensor]) +def annotate_eq(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + +@register_annotator([torch.ops.aten.ge.Scalar, torch.ops.aten.ge.Tensor]) +def annotate_ge(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + +@register_annotator([torch.ops.aten.gt.Scalar, torch.ops.aten.gt.Tensor]) +def annotate_gt(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + +@register_annotator([torch.ops.aten.le.Scalar, torch.ops.aten.le.Tensor]) +def annotate_le(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + +@register_annotator([torch.ops.aten.lt.Scalar, torch.ops.aten.lt.Tensor]) +def annotate_lt(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + 
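All of the new comparison annotators above are one-liners over `annotate_binary`, registered per aten overload. A stripped-down sketch of the decorator/registry mechanism they rely on (`OP_ANNOTATOR` is a stand-in name here, not the backend's actual table):

```python
from typing import Callable, Dict, List

# Stand-in registry: each aten overload maps to exactly one annotator.
OP_ANNOTATOR: Dict[str, Callable] = {}

def register_annotator(ops: List[str]):
    def decorator(fn: Callable) -> Callable:
        for op in ops:
            OP_ANNOTATOR[op] = fn
        return fn
    return decorator

@register_annotator(["aten.eq.Scalar", "aten.eq.Tensor"])
def annotate_eq(node, quantization_config) -> None:
    pass  # would delegate to annotate_binary(node, quantization_config)

assert OP_ANNOTATOR["aten.eq.Tensor"] is annotate_eq
```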
@register_annotator( [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar] ) @@ -179,6 +204,16 @@ def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.max.other, torch.ops.aten.maximum.default]) +def annotate_max(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + +@register_annotator([torch.ops.aten.min.other, torch.ops.aten.minimum.default]) +def annotate_min(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + @register_annotator( [torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor] ) @@ -256,6 +291,32 @@ def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.abs.default]) +def annotate_abs(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + +@register_annotator( + [ + torch.torch.ops.aten.arange.default, + torch.torch.ops.aten.arange.start, + torch.torch.ops.aten.arange.start_step, + ] +) +def annotate_arange(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + if _is_float_tensor(node): + # workaround for node with kwargs could not be correctly annotated + node.kwargs = {} + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map={}, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + @register_annotator([torch.ops.aten.ceil.default]) def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -271,6 +332,11 @@ def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.repeat.default]) +def annotate_repeat(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.cos.default]) def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -286,6 +352,14 @@ def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.full_like.default]) +def annotate_full_like(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_float_tensor(node): + # workaround for node with kwargs could not be correctly annotated + node.kwargs = {} + annotate_single_in_single_out(node, quantization_config) + + @register_annotator( [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default] ) @@ -403,6 +477,11 @@ def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.log.default]) +def annotate_log(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.pad.default]) def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -828,7 +907,7 @@ def annotate_linear(node: Node, 
quantization_config: QuantizationConfig) -> None node.meta["source_fn_stack"] = [(node, torch.nn.Linear)] -@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default]) +@register_annotator([torch.ops.aten.batch_norm.default]) def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None: act, weight, bias = node.args[0:3] if _is_annotated([node]): diff --git a/backends/qualcomm/serialization/qc_compiler_spec.fbs b/backends/qualcomm/serialization/qc_compiler_spec.fbs index 8dd0b93513f..963a4b19fa4 100644 --- a/backends/qualcomm/serialization/qc_compiler_spec.fbs +++ b/backends/qualcomm/serialization/qc_compiler_spec.fbs @@ -17,6 +17,7 @@ enum HtpArch: int { V69 = 69, V73 = 73, V75 = 75, + V79 = 79, } table HtpInfo { @@ -37,6 +38,7 @@ enum QcomChipset: int { SSG2115P = 46, SM8650 = 57, SA8295 = 39, + SM8750 = 69 } /// Indicate the information of the specified SoC. diff --git a/backends/qualcomm/serialization/qc_schema.py b/backends/qualcomm/serialization/qc_schema.py index e03cc842100..25672efccae 100644 --- a/backends/qualcomm/serialization/qc_schema.py +++ b/backends/qualcomm/serialization/qc_schema.py @@ -25,6 +25,7 @@ class HtpArch(IntEnum): V69 = 69 V73 = 73 V75 = 75 + V79 = 79 @dataclass @@ -42,6 +43,7 @@ class QcomChipset(IntEnum): SSG2115P = 46 # v73 SM8650 = 57 # v75 SA8295 = 39 # v68 + SM8750 = 69 # v79 @dataclass @@ -55,6 +57,7 @@ class SocInfo: QcomChipset.SM8475: SocInfo(QcomChipset.SM8475, HtpInfo(HtpArch.V69, 8)), QcomChipset.SM8550: SocInfo(QcomChipset.SM8550, HtpInfo(HtpArch.V73, 8)), QcomChipset.SM8650: SocInfo(QcomChipset.SM8650, HtpInfo(HtpArch.V75, 8)), + QcomChipset.SM8750: SocInfo(QcomChipset.SM8750, HtpInfo(HtpArch.V79, 8)), QcomChipset.SSG2115P: SocInfo(QcomChipset.SSG2115P, HtpInfo(HtpArch.V73, 2)), QcomChipset.SA8295: SocInfo(QcomChipset.SA8295, HtpInfo(HtpArch.V68, 8)), } diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 96aab87826f..d66aa34e5af 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -8,6 +8,14 @@ # module with related operator only +class Abs(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.abs(x) + + class Add(torch.nn.Module): def __init__(self): super().__init__() @@ -33,12 +41,20 @@ def forward(self, x): class Arange(torch.nn.Module): - def __init__(self, x): + def __init__(self, start, end, step, dtype): super().__init__() - self.x = x + self.start = start + self.end = end + self.step = step + self.dtype = dtype def forward(self, y): - return torch.arange(self.x, dtype=torch.float32) + y + return ( + torch.arange( + start=self.start, end=self.end, step=self.step, dtype=self.dtype + ) + + y + ) class AvgPoolModule(torch.nn.Module): @@ -511,6 +527,23 @@ def forward(self, x): return self.embedding(x) +class Equal(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x == y + + +class EqualConstant(torch.nn.Module): + def __init__(self, constant): + super().__init__() + self.constant = constant + + def forward(self, x): + return x == self.constant + + class ExpandCopy(torch.nn.Module): def __init__(self): super().__init__() @@ -519,6 +552,15 @@ def forward(self, x): return x.expand(3, 4) +class FullLike(torch.nn.Module): + def __init__(self, fill): + super().__init__() + self.fill = fill + + def forward(self, x): + return torch.min(x, torch.full_like(x, self.fill)) + + class Gelu(torch.nn.Module): def __init__(self): 
         super().__init__()
@@ -528,6 +570,40 @@ def forward(self, x):
         return self.gelu(x)
 
 
+class GreaterEqual(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return x >= y
+
+
+class GreaterThan(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return x > y
+
+
+class GreaterEqualConstant(torch.nn.Module):
+    def __init__(self, constant):
+        super().__init__()
+        self.constant = constant
+
+    def forward(self, x):
+        return x >= self.constant
+
+
+class GreaterThanConstant(torch.nn.Module):
+    def __init__(self, constant):
+        super().__init__()
+        self.constant = constant
+
+    def forward(self, x):
+        return x > self.constant
+
+
 class GroupNorm(torch.nn.Module):
     def __init__(self, bias=True):
         super().__init__()
@@ -636,6 +712,40 @@ def forward(self, x):
         return self.leaky_relu(x)
 
 
+class LessEqual(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return x <= y
+
+
+class LessThan(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return x < y
+
+
+class LessEqualConstant(torch.nn.Module):
+    def __init__(self, constant):
+        super().__init__()
+        self.constant = constant
+
+    def forward(self, x):
+        return x <= self.constant
+
+
+class LessThanConstant(torch.nn.Module):
+    def __init__(self, constant):
+        super().__init__()
+        self.constant = constant
+
+    def forward(self, x):
+        return x < self.constant
+
+
 class Linear(torch.nn.Module):
     def __init__(self, use_bias: bool = True):
         super().__init__()
@@ -645,6 +755,14 @@ def forward(self, x):
         return self.linear(x)
 
 
+class Log(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.log(x)
+
+
 class LogSoftmax(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -684,6 +802,22 @@ def forward(self, x):
         return torch.mean(x, (-1, -2))
 
 
+class Maximum(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.maximum(x, y)
+
+
+class Minimum(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.minimum(x, y)
+
+
 class Mul(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -807,6 +941,14 @@ def forward(self, x):
         return self.relu(x)
 
 
+class Repeat(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.repeat(1, 2, 3, 4)
+
+
 class Reshape(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 55596cf0386..a6af5335331 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -92,11 +92,20 @@ def setUp(self):
             shared_buffer=TestQNN.shared_buffer,
         )
 
-    def test_qnn_backend_arange(self):
-        module = Arange(5)  # noqa: F405
-        sample_input = (torch.randn(5),)
+    def test_qnn_backend_abs(self):
+        module = Abs()  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_arange(self):
+        modules = [
+            Arange(start=1, end=11, step=1, dtype=torch.int32),  # noqa: F405
+        ]
+        sample_input = (torch.randn(10),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_avg_pool2d(self):
         module = AvgPoolModule()  # noqa: F405
         sample_input = (torch.randn(1, 3, 2, 2),)
@@ -311,16 +320,72 @@ def test_qnn_backend_embedding(self):
         sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_equal(self):
+        test_comb = [
+            {
+                QCOM_MODULE: Equal(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: EqualConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
     def test_qnn_backend_expand_copy(self):
         module = ExpandCopy()  # noqa: F405
         sample_input = (torch.randn([3, 1]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_full_like(self):
+        module = FullLike(0.5)  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_greater_equal(self):
+        test_comb = [
+            {
+                QCOM_MODULE: GreaterEqual(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: GreaterEqualConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
+    def test_qnn_backend_greater_than(self):
+        test_comb = [
+            {
+                QCOM_MODULE: GreaterThan(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: GreaterThanConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
     def test_qnn_backend_group_norm(self):
         modules = [GroupNorm(), GroupNorm(bias=False)]  # noqa: F405
         sample_input = (torch.randn(3, 32, 56, 56),)
@@ -391,16 +456,60 @@ def test_qnn_backend_leaky_relu(self):
             self.lower_module_and_test_output(module, sample_input)
             index += 1
 
+    def test_qnn_backend_less_equal(self):
+        test_comb = [
+            {
+                QCOM_MODULE: LessEqual(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: LessEqualConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
+    def test_qnn_backend_less_than(self):
+        test_comb = [
+            {
+                QCOM_MODULE: LessThan(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: LessThanConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
     def test_qnn_backend_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_log(self):
+        module = Log()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_maximum(self):
+        module = Maximum()  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_max_pool2d(self):
         module = MaxPool2d()  # noqa: F405
         sample_input = (torch.randn(4, 3, 24, 24),)
@@ -419,6 +528,11 @@ def test_qnn_backend_mha(self):
         sample_input = (torch.randn(1, 197, 96),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_minimum(self):
+        module = Minimum()  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pad(self):
         module = Pad()  # noqa: F405
         sample_input = (torch.randn([1, 8, 128]),)
@@ -464,6 +578,11 @@ def test_qnn_backend_relu(self):
         sample_input = (torch.randn([2, 5, 1, 3]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_repeat(self):
+        module = Repeat()  # noqa: F405
+        sample_input = (torch.randn([2, 2, 2, 2]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_reshape(self):
         module = Reshape()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
@@ -790,12 +909,22 @@ def test_qnn_backend_16a4w_per_channel_linear_with_bias(self):
         )
         self.lower_module_and_test_output(module, sample_input)
 
-    def test_qnn_backend_arange(self):
-        module = Arange(5)  # noqa: F405
-        sample_input = (torch.randn(5),)
+    def test_qnn_backend_abs(self):
+        module = Abs()  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_arange(self):
+        modules = [
+            Arange(start=1, end=6, step=0.5, dtype=torch.float32),  # noqa: F405
+        ]
+        sample_input = (torch.randn(10),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_avg_pool2d(self):
         module = AvgPoolModule()  # noqa: F405
         sample_input = (torch.randn(1, 3, 2, 2),)
@@ -1025,18 +1154,78 @@ def test_qnn_backend_embedding(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_equal(self):
+        test_comb = [
+            {
+                QCOM_MODULE: Equal(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: EqualConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
+
     def test_qnn_backend_expand_copy(self):
         module = ExpandCopy()  # noqa: F405
         sample_input = (torch.randn([3, 1]),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_full_like(self):
+        module = FullLike(0.5)  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_greater_equal(self):
+        test_comb = [
+            {
+                QCOM_MODULE: GreaterEqual(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: GreaterEqualConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
+
+    def test_qnn_backend_greater_than(self):
+        test_comb = [
+            {
+                QCOM_MODULE: GreaterThan(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: GreaterThanConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
+
     def test_qnn_backend_group_norm(self):
         modules = [GroupNorm(), GroupNorm(bias=False)]  # noqa: F405
         sample_input = (torch.randn(3, 32, 56, 56),)
@@ -1117,6 +1306,42 @@ def test_qnn_backend_leaky_relu(self):
             self.lower_module_and_test_output(module, sample_input)
             index += 1
 
+    def test_qnn_backend_less_equal(self):
+        test_comb = [
+            {
+                QCOM_MODULE: LessEqual(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: LessEqualConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
+
+    def test_qnn_backend_less_than(self):
+        test_comb = [
+            {
+                QCOM_MODULE: LessThan(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)),
+            },
+            {
+                QCOM_MODULE: LessThanConstant(0.5),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(1, 2, 3, 4),),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
+
     def test_qnn_backend_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
@@ -1133,12 +1358,24 @@ def test_qnn_backend_linear_qat(self):
         module = self.get_converted_sgd_trained_module(module, prepared, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_log(self):
+        module = Log()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_maximum(self):
+        module = Maximum()  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_max_pool2d(self):
         module = MaxPool2d()  # noqa: F405
         sample_input = (torch.randn(4, 3, 24, 24),)
@@ -1159,6 +1396,12 @@ def test_qnn_backend_mha(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_minimum(self):
+        module = Minimum()  # noqa: F405
+        sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pad(self):
         module = Pad()  # noqa: F405
         sample_input = (torch.randn([1, 8, 128]),)
@@ -1210,6 +1453,12 @@ def test_qnn_backend_relu(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_repeat(self):
+        module = Repeat()  # noqa: F405
+        sample_input = (torch.randn([2, 2, 2, 2]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_reshape(self):
         module = Reshape()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
@@ -1854,7 +2103,7 @@ def test_qnn_backend_draw_graph(self):
         delegated_program = capture_program(module, sample_input)
         """
-        This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list. 
+        This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list.
         In practice, users need to set a breakpoint in the preprocessing step and use the DrawGraph tool to visualize the graph.
         """
 
         qnn_compiler_passes = PassManager(
@@ -2364,7 +2613,7 @@ def test_qnn_backend_draw_graph(self):
         delegated_program = capture_program(module, sample_input)
         """
-        This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list. 
+        This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list.
         In practice, users need to set a breakpoint in the preprocessing step and use the DrawGraph tool to visualize the graph.
         """
 
         qnn_compiler_passes = PassManager(
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index e13705b3a8f..a4acae9585b 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -1059,6 +1059,7 @@ def generate_qnn_executorch_compiler_spec(
             SM8475(Snapdragon 8 Gen 1+)
             SM8550(Snapdragon 8 Gen 2)
             SM8650(Snapdragon 8 Gen 3)
+            SM8750(Snapdragon 8 Elite)
         backend_options: Options required by different backends.
         debug: Enable verbose logging. Disclaimer: this option must change in
             the near future.
@@ -1148,6 +1149,7 @@ def get_soc_to_arch_map():
     return {
         "SSG2115P": HtpArch.V73,
+        "SM8750": HtpArch.V79,
         "SM8650": HtpArch.V75,
         "SM8550": HtpArch.V73,
         "SM8475": HtpArch.V69,
@@ -1159,6 +1161,7 @@ def get_soc_to_chipset_map():
     return {
         "SSG2115P": QcomChipset.SSG2115P,
+        "SM8750": QcomChipset.SM8750,
         "SM8650": QcomChipset.SM8650,
         "SM8550": QcomChipset.SM8550,
         "SM8475": QcomChipset.SM8475,
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index c2d2f002aa8..23e384dee1b 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -564,6 +564,8 @@ def generate_inputs(dest_path: str, file_name: str, inputs=None, input_list=None
     for idx, data in enumerate(inputs):
         for i, d in enumerate(data):
             file_name = f"{dest_path}/input_{idx}_{i}.raw"
+            if not isinstance(d, torch.Tensor):
+                d = torch.tensor(d)
             d.detach().numpy().tofile(file_name)
             input_files.append(file_name)
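
Two of the runtime-facing hunks above are easiest to understand with short usage sketches. Both snippets are illustrative only: the file paths and the dump_raw helper name are invented for the example; everything else uses symbols this patch itself touches.

The SM8750 entries make Snapdragon 8 Elite resolvable through the existing lookup helpers:

    from executorch.backends.qualcomm.utils.utils import (
        get_soc_to_arch_map,
        get_soc_to_chipset_map,
    )

    # With this patch applied, the new SoC resolves to HTP arch V79
    # and to the QcomChipset.SM8750 enum value.
    arch = get_soc_to_arch_map()["SM8750"]
    soc_model = get_soc_to_chipset_map()["SM8750"]

The generate_inputs() change guards against inputs that are plain Python scalars rather than tensors; previously such values would raise on d.detach(). A stand-alone sketch of the coercion, assuming only stock PyTorch:

    import torch

    # Hypothetical helper mirroring the fixed inner loop of generate_inputs():
    # wrap non-tensor inputs in a tensor so the raw dump below works uniformly.
    def dump_raw(d, file_name):
        if not isinstance(d, torch.Tensor):
            d = torch.tensor(d)
        d.detach().numpy().tofile(file_name)

    dump_raw(torch.randn(2, 3), "/tmp/input_0_0.raw")
    dump_raw(5, "/tmp/input_0_1.raw")  # a scalar input now serializes too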