From 2a3c818497bb72d6fe646a9d91b5793f4a069264 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Fri, 21 Feb 2025 11:39:06 +0800 Subject: [PATCH 1/2] Qualcomm AI Engine Direct - Moshi/Mimi Enablement --- backends/qualcomm/_passes/__init__.py | 14 +- backends/qualcomm/_passes/annotate_stack.py | 36 +++ .../_passes/convert_conv1d_to_conv2d.py | 97 +++++++ backends/qualcomm/_passes/decompose_expm1.py | 47 ++++ backends/qualcomm/_passes/decompose_silu.py | 13 +- backends/qualcomm/_passes/layout_transform.py | 2 + .../qualcomm/_passes/remove_empty_tensor.py | 35 +++ .../qualcomm/_passes/replace_arange_args.py | 49 ++++ .../qualcomm/_passes/replace_inf_buffer.py | 26 -- .../qualcomm/_passes/tensor_i64_to_i32.py | 11 +- backends/qualcomm/_passes/utils.py | 14 + backends/qualcomm/builders/__init__.py | 8 + backends/qualcomm/builders/op_and.py | 59 +++++ backends/qualcomm/builders/op_conv2d.py | 169 +----------- backends/qualcomm/builders/op_elu.py | 68 +++++ backends/qualcomm/builders/op_exp.py | 59 +++++ backends/qualcomm/builders/op_pad.py | 7 +- .../qualcomm/builders/op_scalar_tensor.py | 49 ++++ backends/qualcomm/builders/op_select_copy.py | 1 - backends/qualcomm/builders/op_sqrt.py | 4 +- backends/qualcomm/builders/qnn_constants.py | 26 +- backends/qualcomm/partition/common_defs.py | 3 + backends/qualcomm/quantizer/annotators.py | 67 ++++- backends/qualcomm/quantizer/quantizer.py | 11 +- backends/qualcomm/tests/models.py | 76 ++++++ backends/qualcomm/tests/test_qnn_delegate.py | 80 +++++- backends/qualcomm/utils/utils.py | 16 +- .../oss_scripts/moshi/install_requirments.sh | 15 ++ examples/qualcomm/oss_scripts/moshi/mimi.py | 241 ++++++++++++++++++ .../oss_scripts/moshi/moshi_example.py | 189 ++++++++++++++ examples/qualcomm/utils.py | 2 +- 31 files changed, 1259 insertions(+), 235 deletions(-) create mode 100644 backends/qualcomm/_passes/annotate_stack.py create mode 100644 backends/qualcomm/_passes/convert_conv1d_to_conv2d.py create mode 100644 backends/qualcomm/_passes/decompose_expm1.py create mode 100644 backends/qualcomm/_passes/remove_empty_tensor.py create mode 100644 backends/qualcomm/_passes/replace_arange_args.py delete mode 100644 backends/qualcomm/_passes/replace_inf_buffer.py create mode 100644 backends/qualcomm/builders/op_and.py create mode 100644 backends/qualcomm/builders/op_elu.py create mode 100644 backends/qualcomm/builders/op_exp.py create mode 100644 backends/qualcomm/builders/op_scalar_tensor.py create mode 100755 examples/qualcomm/oss_scripts/moshi/install_requirments.sh create mode 100644 examples/qualcomm/oss_scripts/moshi/mimi.py create mode 100644 examples/qualcomm/oss_scripts/moshi/moshi_example.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index c5499b52d80..d3761208fc9 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -1,11 +1,14 @@ from .annotate_decomposed import AnnotateDecomposed from .annotate_quant_attrs import AnnotateQuantAttrs +from .annotate_stack import AnnotateStack from .constant_i64_to_i32 import ConstantI64toI32 from .convert_bmm_to_matmul import ConvertBmmToMatmul +from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D from .convert_to_linear import ConvertToLinear from .decompose_any import DecomposeAny from .decompose_einsum import DecomposeEinsum +from .decompose_expm1 import DecomposeExpM1 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm from 
.decompose_silu import DecomposeSilu from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape @@ -19,22 +22,27 @@ from .recompose_prelu import RecomposePReLU from .recompose_rms_norm import RecomposeRmsNorm from .reduce_dynamic_range import ReduceDynamicRange +from .remove_empty_tensor import RemoveEmptyTensor from .remove_redundancy import RemoveRedundancy +from .replace_arange_args import ReplaceArangeArgs from .replace_index_put_input import ReplaceIndexPutInput -from .replace_inf_buffer import ReplaceInfBuffer +from .replace_inf_values import ReplaceInfValues from .tensor_i64_to_i32 import TensorI64toI32 __all__ = [ AnnotateDecomposed, AnnotateQuantAttrs, + AnnotateStack, ConstantI64toI32, ConvertBmmToMatmul, + ConvertConv1dToConv2d, ConvertInterpolateWithUpsample2D, RecomposePReLU, ConvertToLinear, DecomposeAny, DecomposeEinsum, + DecomposeExpM1, DecomposeLinalgVectorNorm, DecomposeSilu, ExpandBroadcastTensorShape, @@ -47,8 +55,10 @@ RecomposePixelUnshuffle, RecomposeRmsNorm, ReduceDynamicRange, + RemoveEmptyTensor, RemoveRedundancy, + ReplaceArangeArgs, ReplaceIndexPutInput, - ReplaceInfBuffer, + ReplaceInfValues, TensorI64toI32, ] diff --git a/backends/qualcomm/_passes/annotate_stack.py b/backends/qualcomm/_passes/annotate_stack.py new file mode 100644 index 00000000000..e2565ba9356 --- /dev/null +++ b/backends/qualcomm/_passes/annotate_stack.py @@ -0,0 +1,36 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS +from executorch.exir.pass_base import ExportPass, PassResult + + +class AnnotateStack(ExportPass): + """ + During decomposition stage, some unsqueeze op will appear. + These unsqueeze op does not carry quant attributes and will need to use previous node's quant attributes + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if ( + node.meta.get("torch_fn", ("", ""))[1] + == "builtin_function_or_method.stack" + ): + if ( + QCOM_QUANT_ATTRS not in node.meta + and QCOM_QUANT_ATTRS in node.args[0].meta + ): + node.meta[QCOM_QUANT_ATTRS] = node.args[0].meta[QCOM_QUANT_ATTRS] + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py new file mode 100644 index 00000000000..61aad6efdbb --- /dev/null +++ b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py @@ -0,0 +1,97 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_meta + + +class ConvertConv1dToConv2d(ExportPass): + """ + Conv1d is not supported by QNN. 
+ Change it to input -> unsqueeze -> conv2d -> squeeze -> output + """ + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(ConvertConv1dToConv2d, self).__init__() + self.edge_program = edge_program + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + conv_op = exir_ops.edge.aten.convolution.default + for node in graph.nodes: + if node.target == conv_op and node.meta["val"].dim() == 3: + + input_node = node.args[0] + with graph_module.graph.inserting_after(input_node): + unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default + unsqueeze_node = graph.create_node( + "call_function", + unsqueeze_op, + ( + input_node, + 2, + ), + ) + unsqueeze_node.meta = copy_meta(input_node.meta) + unsqueeze_node.meta["val"] = unsqueeze_node.meta["val"].unsqueeze(2) + with graph_module.graph.inserting_after(unsqueeze_node): + + filter_node = node.args[1] + filter_node.meta["val"] = ( + filter_node.meta["val"].unsqueeze(2).contiguous() + ) + filter_tensor = get_parameter(filter_node, self.edge_program) + # Wrap with nn.Parameter. In FP mode, unsqueeze will make output not a nn.Parameter, which makes program to fail during edge_program._validate() + filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2)) + set_parameter(filter_tensor, filter_node, self.edge_program) + + bias_node = node.args[2] + stride = [1] + node.args[3] + padding = [0] + node.args[4] + dilation = [1] + node.args[5] + transpose = node.args[6] + output_padding = [0] + node.args[7] + groups = node.args[8] + + conv2d_node = graph.create_node( + "call_function", + conv_op, + ( + unsqueeze_node, + filter_node, + bias_node, + stride, + padding, + dilation, + transpose, + output_padding, + groups, + ), + ) + conv2d_node.meta = copy_meta(node.meta) + conv2d_node.meta["val"] = conv2d_node.meta["val"].unsqueeze(2) + + with graph_module.graph.inserting_after(conv2d_node): + squeeze_op = exir_ops.edge.aten.squeeze_copy.dims + squeeze_node = graph.create_node( + "call_function", + squeeze_op, + ( + conv2d_node, + [2], + ), + ) + squeeze_node.meta = copy_meta(node.meta) + for user in node.users.copy(): + user.replace_input_with(node, squeeze_node) + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_expm1.py b/backends/qualcomm/_passes/decompose_expm1.py new file mode 100644 index 00000000000..f9cdbc42bcc --- /dev/null +++ b/backends/qualcomm/_passes/decompose_expm1.py @@ -0,0 +1,47 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_meta + + +class DecomposeExpM1(ExportPass): + """ + Decompose for expm1 to exponential and minus 1. 
+ """ + + def __init__(self, quantization_capture=False) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target == torch.ops.aten.special_expm1.default: + input_node = node.args[0] + with graph_module.graph.inserting_after(input_node): + exp_op = torch.ops.aten.exp.default + exp_node = graph.create_node("call_function", exp_op, (input_node,)) + exp_node.meta = copy_meta(node.meta) + with graph_module.graph.inserting_after(exp_node): + sub_op = torch.ops.aten.sub.Tensor + sub_node = graph.create_node( + "call_function", + sub_op, + ( + exp_node, + 1, + ), + ) + sub_node.meta = copy_meta(node.meta) + for user in node.users.copy(): + user.replace_input_with(node, sub_node) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_silu.py b/backends/qualcomm/_passes/decompose_silu.py index 96c48920419..c3ac45a8d9d 100644 --- a/backends/qualcomm/_passes/decompose_silu.py +++ b/backends/qualcomm/_passes/decompose_silu.py @@ -3,22 +3,17 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict import torch from executorch.exir.pass_base import ExportPass, PassResult +from .utils import copy_meta + class DecomposeSilu(ExportPass): def __init__(self): super(DecomposeSilu, self).__init__() - def _copy_meta(self, meta: Dict): - copied = {} - for k, v in meta.items(): - copied[k] = v - return copied - def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph for node in graph.nodes: @@ -34,14 +29,14 @@ def call(self, graph_module: torch.fx.GraphModule): torch.ops.aten.sigmoid.default, (silu_node_input,), ) - sigmoid_node.meta = self._copy_meta(silu_node.meta) + sigmoid_node.meta = copy_meta(silu_node.meta) with graph_module.graph.inserting_after(sigmoid_node): mul_node = graph.create_node( "call_function", torch.ops.aten.mul.Tensor, (silu_node_input, sigmoid_node), ) - mul_node.meta = self._copy_meta(silu_node.meta) + mul_node.meta = copy_meta(silu_node.meta) for user in silu_node.users.copy(): user.replace_input_with(silu_node, mul_node) diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 967ae7afd2b..f3b8fd59065 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -53,7 +53,9 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.clamp.default, exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.elu.default, exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.exp.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.ge.Tensor, diff --git a/backends/qualcomm/_passes/remove_empty_tensor.py b/backends/qualcomm/_passes/remove_empty_tensor.py new file mode 100644 index 00000000000..8b1ae6d0450 --- /dev/null +++ b/backends/qualcomm/_passes/remove_empty_tensor.py @@ -0,0 +1,35 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RemoveEmptyTensor(ExportPass): + """ + QNN does not allow 0D tensor, we remove the node that will output an empty tensor. + Before adding operations to the list of nodes to be removed, please ensure that it will not change the logic. + """ + + remove_ops = { + exir_ops.edge.aten.select.int, + } + + def __init__(self, quantization_capture=False) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target in self.remove_ops and len(node.meta["val"].shape) == 0: + for user_n in list(node.users.keys()): + user_n.replace_input_with(node, node.args[0]) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/replace_arange_args.py b/backends/qualcomm/_passes/replace_arange_args.py new file mode 100644 index 00000000000..a96e5091fb1 --- /dev/null +++ b/backends/qualcomm/_passes/replace_arange_args.py @@ -0,0 +1,49 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_meta + + +class ReplaceArangeArgs(ExportPass): + """ + During annotation, kwargs for arange will be removed due to restrictions by quantizer. + This causes arange having no dtype, which means FP nodes might become an INT node during calibration. + This can cause calibration to fail since QDQ can only be applied on FP nodes but not INT nodes. + To hint the dtype, we provide step size as 1.0 instead of 1, which makes the node a fp node. + """ + + def __init__(self, quantization_capture=False) -> None: + super().__init__() + self.quantization_capture = quantization_capture + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target == torch.ops.aten.arange.default: + if torch.is_floating_point(node.meta["val"]) and len(node.args) == 1: + with graph_module.graph.inserting_after(node): + step_arange_op = torch.torch.ops.aten.arange.start_step + step_arange_node = graph.create_node( + "call_function", + step_arange_op, + ( + 0, + node.args[0], + 1.0, + ), + ) + step_arange_node.meta = copy_meta(node.meta) + + for user in node.users.copy(): + user.replace_input_with(node, step_arange_node) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/replace_inf_buffer.py b/backends/qualcomm/_passes/replace_inf_buffer.py deleted file mode 100644 index 776bc9beeba..00000000000 --- a/backends/qualcomm/_passes/replace_inf_buffer.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch -from executorch.exir.pass_base import ExportPass, PassResult - - -class ReplaceInfBuffer(ExportPass): - """ - Due to limitation in Qnn, we need to change inf or -inf to arbitrary value in quantization. 
- """ - - def __init__(self): - super(ReplaceInfBuffer, self).__init__() - - def call(self, graph_module: torch.fx.GraphModule): - for buf_name, tensor in graph_module.named_buffers(): - if tensor.is_floating_point(): - tensor[tensor == float("inf")] = 255 - tensor[tensor == float("-inf")] = -255 - setattr(graph_module, buf_name, tensor) - - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/tensor_i64_to_i32.py b/backends/qualcomm/_passes/tensor_i64_to_i32.py index b590e30884c..6950bc70ea2 100644 --- a/backends/qualcomm/_passes/tensor_i64_to_i32.py +++ b/backends/qualcomm/_passes/tensor_i64_to_i32.py @@ -24,6 +24,9 @@ class TensorI64toI32(ExportPass): cast_ops = { torch.ops.aten.argmin.default, + torch.ops.aten.arange.start_step, + torch.ops.aten.full.default, + torch.ops.aten.scalar_tensor.default, } def __init__(self, edge_program): @@ -45,6 +48,7 @@ def _cast_to_int32(self, core_ep: ExirExportedProgram): else: n.meta[QCOM_ORIG_DTYPE] = n.meta["val"].dtype continue + if n.target in self.cast_ops: node_val = n.meta["val"] if self._is_tensor_of_dtype(node_val, torch.int64): @@ -61,7 +65,12 @@ def _cast_to_int32(self, core_ep: ExirExportedProgram): cast_node.args = args for user in users: - user.replace_input_with(n, cast_node) + # This node is used to check dtype, which will cause lowering to fail since we are changing int64 to int32 + if ( + user.target + != torch.ops.aten._assert_tensor_metadata.default + ): + user.replace_input_with(n, cast_node) core_ep.exported_program._graph_signature = _get_updated_graph_signature( core_ep.exported_program._graph_signature, diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 68056d53aca..8db33415672 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Dict + import torch from executorch.backends.qualcomm.builders.utils import get_parameter from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING @@ -24,6 +26,13 @@ } +def copy_meta(meta: Dict): + copied = {} + for k, v in meta.items(): + copied[k] = v + return copied + + def get_quant_attrs( edge_program: torch.export.ExportedProgram, quant_node: torch.fx.Node ): @@ -60,8 +69,10 @@ def get_passes_dependency_for_capture_program(): from executorch.backends.qualcomm._passes import ( AnnotateDecomposed, AnnotateQuantAttrs, + AnnotateStack, ConstantI64toI32, ConvertBmmToMatmul, + ConvertConv1dToConv2d, ConvertInterpolateWithUpsample2D, ConvertToLinear, DecomposeAny, @@ -87,8 +98,10 @@ def get_passes_dependency_for_capture_program(): ConvertBmmToMatmul, ConvertInterpolateWithUpsample2D, ], + AnnotateStack: [FoldQDQ], ConstantI64toI32: [ConvertInterpolateWithUpsample2D], ConvertBmmToMatmul: [ConvertToLinear], + ConvertConv1dToConv2d: [FoldQDQ], ConvertInterpolateWithUpsample2D: [RemoveRedundancy], ConvertToLinear: [RecomposePixelUnshuffle], DecomposeAny: [RemoveRedundancy], @@ -97,6 +110,7 @@ def get_passes_dependency_for_capture_program(): FoldQDQ: [AnnotateQuantAttrs, AnnotateDecomposed], LayoutTransform: [ AnnotateQuantAttrs, + ConvertConv1dToConv2d, ExpandBroadcastTensorShape, ], RecomposePixelUnshuffle: [RemoveRedundancy], diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index a16d4fb5057..5e9573932bf 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -9,6 +9,7 @@ op_abs, op_adaptive_avg_pool2d, op_add, + op_and, op_arange, op_argmin, op_avg_pool2d, @@ -22,8 +23,10 @@ op_depth_to_space, op_dequantize, op_div, + op_elu, op_embedding, op_eq, + op_exp, op_expand, op_full, op_full_like, @@ -61,6 +64,7 @@ op_reshape, op_rms_norm, op_rsqrt, + op_scalar_tensor, op_select_copy, op_sigmoid, op_sin, @@ -88,6 +92,7 @@ op_abs, op_adaptive_avg_pool2d, op_add, + op_and, op_arange, op_argmin, op_avg_pool2d, @@ -101,8 +106,10 @@ op_depth_to_space, op_dequantize, op_div, + op_elu, op_embedding, op_eq, + op_exp, op_expand, op_full, op_full_like, @@ -140,6 +147,7 @@ op_reshape, op_rms_norm, op_rsqrt, + op_scalar_tensor, op_select_copy, op_sigmoid, op_sin, diff --git a/backends/qualcomm/builders/op_and.py b/backends/qualcomm/builders/op_and.py new file mode 100644 index 00000000000..44e6f2893f5 --- /dev/null +++ b/backends/qualcomm/builders/op_and.py @@ -0,0 +1,59 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
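
The table returned by `get_passes_dependency_for_capture_program` appears to map each pass to the passes that must run before it, so any topological order over the table is a valid pass schedule. A generic sketch of ordering such a table (toy string stand-ins for the real pass classes; the consuming scheduler is assumed, not shown in this patch):

```python
from graphlib import TopologicalSorter

deps = {
    "ConvertConv1dToConv2d": ["FoldQDQ"],
    "AnnotateStack": ["FoldQDQ"],
    "FoldQDQ": ["AnnotateQuantAttrs"],
    "AnnotateQuantAttrs": [],
}
order = list(TopologicalSorter(deps).static_order())
# prerequisites come out before the passes that need them
assert order.index("FoldQDQ") < order.index("ConvertConv1dToConv2d")
print(order)
```
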
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpElementWiseAnd, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class OpAnd(NodeVisitor): + target = ["aten.bitwise_and.Tensor"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + and_output_tensors = [output_tensor_wrapper] + + and_input_tensors = [] + for index in range(2): + input_node = node.args[index] + input_tensor = self.get_tensor(input_node, node) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + tensor_type, + nodes_to_wrappers, + ) + and_input_tensors.append(input_tensor_wrapper) + and_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseAnd.op_name, + ) + and_op.AddInputTensors(and_input_tensors) + and_op.AddOutputTensors(and_output_tensors) + return and_op diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index a6051636d3e..933c12742cd 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import warnings from typing import cast, Dict, List import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -17,8 +16,6 @@ from .qnn_constants import ( OpConv2d, OpDepthWiseConv2d, - OpExpandDims, - OpReshape, OpTransposeConv2d, QNN_OP_PACKAGE_NAME_QTI_AISW, ) @@ -102,176 +99,16 @@ def _add_conv_op_parameter( return conv_op - def _define_conv1d( - self, - node: torch.fx.Node, - nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: - """ - Conv1D is a special case for convolutional operation. QNN does not support Conv1D, therefore, - we need to cast from input -> Conv1d -> output to input -> unsqueeze -> Conv2d -> squeeze -> output. 
- """ - transpose_conv = cast(bool, node.args[6]) - if transpose_conv: - print("ConvTranspose1d is not yet supported") - return - - op_wrapper_list = [] # op_wrapper to return - unsqueeze_input_node = node.args[0] - input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf( - unsqueeze_input_node, node - ) - - unsqueeze_input_tensor = self.get_tensor(unsqueeze_input_node, node) - unsqueeze_input_tensor_wrapper = self.define_tensor( - unsqueeze_input_node, - node, - unsqueeze_input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - ) - unsqueeze_output_tensor = unsqueeze_input_tensor.unsqueeze(1).contiguous() - dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs) - unsqueeze_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_unsqueeze", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=dtype, - quant_encoding=input_quant_encoding, - quant_configs=input_quant_configs, - dims=unsqueeze_output_tensor.size(), - tensor=unsqueeze_output_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - unsqueeze_op = PyQnnWrapper.PyQnnOpWrapper( - node.name + "_unsqueeze", - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpExpandDims.op_name, - ) - unsqueeze_op.AddInputTensors([unsqueeze_input_tensor_wrapper]) - unsqueeze_op.AddOutputTensors([unsqueeze_output_tensor_wrapper]) - unsqueeze_op.AddScalarParam( - OpExpandDims.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {QCOM_DATA: np.uint32(1)}, - ) - op_wrapper_list.append(unsqueeze_op) - - filter_node = node.args[1] - filter_tensor = ( - get_parameter(filter_node, self.edge_program).unsqueeze(2).contiguous() - ) - filter_axis_order = (2, 3, 1, 0) - filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() - filter_tensor_wrapper = self.define_tensor( - filter_node, - node, - filter_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper] - if node.args[2] is not None: - bias_node = node.args[2] - bias_tensor = get_parameter(bias_node, self.edge_program) - bias_tensor_wrapper = self.define_tensor( - bias_node, - node, - bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - conv_input_tensors.append(bias_tensor_wrapper) - - stride = [1] + cast(List[int], node.args[3]) - padding = [0] + cast(List[int], node.args[4]) - dilation = [1] + cast(List[int], node.args[5]) - groups = cast(int, node.args[8]) - - # args[6] = transposed - if cast(bool, node.args[6]): - warnings.warn( - "[QNN Delegate Op Builder]: Currently, No support for transposed convolution.", - stacklevel=1, - ) - return - - # args[7] = output padding - if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])): - warnings.warn( - "[QNN Delegate Op Builder]: QNN does not support output padding.", - stacklevel=1, - ) - return - - stride_shape = [len(stride)] - padding_shape = [2, 2] - dilation_shape = [len(dilation)] - - conv_op = PyQnnWrapper.PyQnnOpWrapper( - node.name + "_squeeze", - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpConv2d.op_name, - ) - conv_output_tensor = self.get_tensor(node, node) - conv_output_tensor = conv_output_tensor.unsqueeze(1).contiguous() - dtype = self.get_data_type(conv_output_tensor, input_quant_configs) - conv_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_squeeze", - 
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=dtype, - quant_encoding=input_quant_encoding, - quant_configs=input_quant_configs, - dims=conv_output_tensor.size(), - tensor=conv_output_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - conv_op = self._add_conv_op_parameter( - OpConv2d, - conv_op, - conv_input_tensors, - [conv_output_tensor_wrapper], - stride, - stride_shape, - padding, - padding_shape, - dilation, - dilation_shape, - groups=groups, - ) - op_wrapper_list.append(conv_op) - - squeeze_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpReshape.op_name, - ) - squeeze_output_tensor = self.get_tensor(node, node) - squeeze_output_tensor_wrapper = self.define_tensor( - node, - node, - squeeze_output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - node_name=node.name, - ) - squeeze_op.AddInputTensors([conv_output_tensor_wrapper]) - squeeze_op.AddOutputTensors([squeeze_output_tensor_wrapper]) - op_wrapper_list.append(squeeze_op) - - return op_wrapper_list - def define_node( self, node: torch.fx.Node, nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - if get_parameter(node.args[1], self.edge_program).dim() == 3: - return self._define_conv1d(node, nodes_to_wrappers) - input_node = node.args[0] input_tensor = self.get_tensor(input_node, node) + assert ( + input_tensor.dim() == 4 + ), "All Conv Should be converted to Conv2D in ConvertConv1dToConv2d" input_tensor_wrapper = self.define_tensor( input_node, node, diff --git a/backends/qualcomm/builders/op_elu.py b/backends/qualcomm/builders/op_elu.py new file mode 100644 index 00000000000..f9cc089c7bb --- /dev/null +++ b/backends/qualcomm/builders/op_elu.py @@ -0,0 +1,68 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
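
With `_define_conv1d` removed, the conv builder above now asserts a 4-D input and relies entirely on the `ConvertConv1dToConv2d` pass. A standalone sketch of the unsqueeze → conv2d → squeeze equivalence that pass (and the deleted builder path) is based on:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 16)   # (N, C, L)
w = torch.randn(8, 4, 3)    # (C_out, C_in, K)
ref = F.conv1d(x, w, stride=2, padding=1)

# insert a height dimension of 1, prepend unit stride / zero padding for it, then drop it again
out = F.conv2d(x.unsqueeze(2), w.unsqueeze(2), stride=(1, 2), padding=(0, 1)).squeeze(2)
assert torch.allclose(ref, out, atol=1e-5)
```
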
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpElu, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Elu(NodeVisitor): + target = ["aten.elu.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + # tensor input + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + elu_input_tensors = [input_tensor_wrapper] + + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + elu_output_tensors = [output_tensor_wrapper] + + elu_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElu.op_name, + ) + elu_op.AddInputTensors(elu_input_tensors) + elu_op.AddOutputTensors(elu_output_tensors) + + if len(node.args) == 2: + elu_op.AddScalarParam( + OpElu.param_alpha, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.uint32(node.args[1])}, + ) + + return elu_op diff --git a/backends/qualcomm/builders/op_exp.py b/backends/qualcomm/builders/op_exp.py new file mode 100644 index 00000000000..8c4794c9725 --- /dev/null +++ b/backends/qualcomm/builders/op_exp.py @@ -0,0 +1,59 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
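
The `Elu` builder above maps `aten.elu.default` onto QNN's `Elu` op with an optional `alpha` parameter. For reference, the function being lowered is elu(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise:

```python
import torch
import torch.nn.functional as F

x = torch.randn(16)
alpha = 0.5
manual = torch.where(x > 0, x, alpha * (torch.exp(x) - 1))
assert torch.allclose(F.elu(x, alpha=alpha), manual, atol=1e-6)
```
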
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpElementWiseExp, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Exp(NodeVisitor): + target = ["aten.exp.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + # tensor input + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + exp_input_tensors = [input_tensor_wrapper] + + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + exp_output_tensors = [output_tensor_wrapper] + + exp_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseExp.op_name, + ) + exp_op.AddInputTensors(exp_input_tensors) + exp_op.AddOutputTensors(exp_output_tensors) + + return exp_op diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py index 10948859be9..bcc6ded6454 100644 --- a/backends/qualcomm/builders/op_pad.py +++ b/backends/qualcomm/builders/op_pad.py @@ -47,20 +47,19 @@ def define_node( nodes_to_wrappers, ) pad_output_tensors = [output_tensor_wrapper] - pad_amount_shape = [input_tensor.dim(), 2] # pytorch padding start from the last index pad_amount = np.reshape(cast(List[int], node.args[1]), (-1, 2))[::-1].astype( np.uint32 ) - # fullfill the pad amount for each idex of tensor + # fulfill the pad amount for each idex of tensor if zero_amounts := pad_amount_shape[0] - pad_amount.shape[0]: pad_amount = np.concatenate( (np.array([(0, 0)] * zero_amounts), pad_amount) ).astype(np.uint32) - if QCOM_AXIS_ORDER in node.meta: - pad_amount = np.transpose(pad_amount, node.meta[QCOM_AXIS_ORDER]) + pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])] + pad_amount_val = node.args[2] pad_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_scalar_tensor.py b/backends/qualcomm/builders/op_scalar_tensor.py new file mode 100644 index 00000000000..14a95bfb5d0 --- /dev/null +++ b/backends/qualcomm/builders/op_scalar_tensor.py @@ -0,0 +1,49 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
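
The `op_pad` change above reorders the per-dimension pad amounts by indexing rows with the stored axis order instead of transposing the array. A small sketch of that bookkeeping for an NCHW tensor lowered to NHWC (the (0, 2, 3, 1) permutation is assumed here purely for illustration):

```python
import numpy as np

# aten.constant_pad_nd lists pads from the last dimension: (W_left, W_right, H_top, H_bottom)
torch_pad = [1, 2, 3, 4]
per_dim = np.reshape(torch_pad, (-1, 2))[::-1]                        # rows now in (H, W) order
per_dim = np.concatenate(([[0, 0]] * 2, per_dim)).astype(np.uint32)   # prepend (N, C) rows
axis_order = (0, 2, 3, 1)                                             # NCHW -> NHWC
nhwc_pad = per_dim[list(axis_order)]
print(nhwc_pad.tolist())  # [[0, 0], [3, 4], [1, 2], [0, 0]] -> (N, H, W, C)
```
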
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor + + +@register_node_visitor +class Arange(NodeVisitor): + target = ["scalar_tensor.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + val = node.args[0] + out_tensor = torch.tensor([val], dtype=node.meta["val"].dtype) + + # negative infinite + if torch.isinf(out_tensor)[0] and (out_tensor < 0): + out_tensor = torch.tensor( + [torch.finfo(torch.float32).min], dtype=node.meta["val"].dtype + ) + # positive infinite + elif torch.isinf(out_tensor)[0] and (out_tensor > 0): + out_tensor = torch.tensor( + [torch.finfo(torch.float32).max], dtype=node.meta["val"].dtype + ) + # since we can derive the constant value of current op in AoT stage + # we only build static tensor here for consumers of current node + # to correctly reference the data + self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py index 148888f1497..783bab42f01 100644 --- a/backends/qualcomm/builders/op_select_copy.py +++ b/backends/qualcomm/builders/op_select_copy.py @@ -84,5 +84,4 @@ def define_node( PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(math.pow(2, dim))}, ) - return stride_slice_op diff --git a/backends/qualcomm/builders/op_sqrt.py b/backends/qualcomm/builders/op_sqrt.py index dc6691460ca..030e6c3e10a 100644 --- a/backends/qualcomm/builders/op_sqrt.py +++ b/backends/qualcomm/builders/op_sqrt.py @@ -10,7 +10,7 @@ import torch from .node_visitor import NodeVisitor, register_node_visitor -from .qnn_constants import OpSqrt, QNN_OP_PACKAGE_NAME_QTI_AISW +from .qnn_constants import OpElementWiseSqrt, QNN_OP_PACKAGE_NAME_QTI_AISW @register_node_visitor @@ -51,7 +51,7 @@ def define_node( sqrt_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, - OpSqrt.op_name, + OpElementWiseSqrt.op_name, ) sqrt_op.AddInputTensors(sqrt_input_tensors) sqrt_op.AddOutputTensors(sqrt_output_tensors) diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 5e0b63d6d19..c7c2d8666b2 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -85,6 +85,11 @@ class OpElementWiseAdd: op_name: str = "ElementWiseAdd" +@dataclass(init=False, frozen=True) +class OpElementWiseAnd: + op_name: str = "ElementWiseAnd" + + @dataclass(init=False, frozen=True) class OpElementWiseCeil: op_name = "ElementWiseCeil" @@ -100,6 +105,11 @@ class OpElementWiseDivide: op_name: str = "ElementWiseDivide" +@dataclass(init=False, frozen=True) +class OpElementWiseExp: + op_name: str = "ElementWiseExp" + + @dataclass(init=False, frozen=True) class OpElementWiseEqual: op_name: str = "ElementWiseEqual" @@ -188,11 +198,22 @@ class OpElementWiseSelect: op_name = "ElementWiseSelect" +@dataclass(init=False, frozen=True) +class OpElementWiseSqrt: + op_name: str = "ElementWiseSquareRoot" + + @dataclass(init=False, frozen=True) class OpElementWiseSubtract: op_name = "ElementWiseSubtract" +@dataclass(init=False, frozen=True) +class OpElu: + op_name: str = "Elu" + param_alpha: str = "alpha" + + 
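
The `scalar_tensor` builder above folds the op into a static tensor at AoT time and clamps infinite fill values to the float32 range, since QNN has no representation for ±inf. A minimal sketch of that substitution:

```python
import torch

def clamp_inf_scalar(val: float, dtype=torch.float32) -> torch.Tensor:
    t = torch.tensor([val], dtype=dtype)
    if torch.isinf(t)[0]:
        bound = torch.finfo(torch.float32).min if t[0] < 0 else torch.finfo(torch.float32).max
        t = torch.tensor([bound], dtype=dtype)
    return t

print(clamp_inf_scalar(float("-inf")))  # tensor([-3.4028e+38])
print(clamp_inf_scalar(0.5))            # tensor([0.5000])
```
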
@dataclass(init=False, frozen=True) class OpExpandDims: op_name: str = "ExpandDims" @@ -418,11 +439,6 @@ class OpSplit: param_split_index: str = "split_index" -@dataclass(init=False, frozen=True) -class OpSqrt: - op_name: str = "ElementWiseSquareRoot" - - @dataclass(init=False, frozen=True) class OpSqueeze: op_name: str = "Squeeze" diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 8254bb64db0..b427c59ce07 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. import _operator +import torch + from executorch.exir.dialects._ops import ops as exir_ops not_supported_operator = [ @@ -20,6 +22,7 @@ exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, + torch.ops.aten.scalar_tensor.default, ] allow_list_operator = [ diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index a232d231c27..c87f4686a80 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -331,7 +331,6 @@ def annotate_abs(node: Node, quantization_config: QuantizationConfig) -> None: def annotate_arange(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): return - if _is_float_tensor(node): # workaround for node with kwargs could not be correctly annotated node.kwargs = {} @@ -378,6 +377,20 @@ def annotate_sin(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.scalar_tensor.default]) +def annotate_scalar_tensor(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + if _is_float_tensor(node): + # workaround for node with kwargs could not be correctly annotated + node.kwargs = {} + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map={}, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + @register_annotator([torch.ops.aten.tanh.default]) def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -712,6 +725,11 @@ def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> N annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.elu.default]) +def annotate_elu(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.embedding.default]) def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None: weight = node.args[0] @@ -758,6 +776,13 @@ def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> N ) +@register_annotator([torch.ops.aten.exp.default]) +def annotate_exp(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.expand.default, torch.ops.aten.expand_as.default]) def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None: annotate_in_out_obs_sharing_op(node, quantization_config) @@ -889,6 +914,7 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> 
None: torch.ops.aten.conv2d.default, torch.ops.aten.conv1d.default, torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose1d.default, ] ) def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: @@ -1078,19 +1104,20 @@ def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None if _is_annotated([node]): return - input_qspec_map = {} - input_act = node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation - - node_tensor = node.meta.get("val") - if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64: - return + # Seems like unbind.int can be either float or int. Only quant when input is float. + if _is_float_tensor(node.args[0]): + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = quantization_config.input_activation + node_tensor = node.meta.get("val") + if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64: + return - node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - _annotated=True, - ) + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) @register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default]) @@ -1136,3 +1163,17 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None: _annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated([node]) + + +@register_annotator([torch.ops.aten.zeros.default]) +def annotate_zeros(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + if _is_float_tensor(node): + # workaround for node with kwargs could not be correctly annotated + node.kwargs = {} + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map={}, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index f5f07f6a365..e983613c518 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -10,12 +10,14 @@ import torch from executorch.backends.qualcomm._passes import ( DecomposeEinsum, + DecomposeExpM1, DecomposeLinalgVectorNorm, DecomposeSilu, LiftConstantScalarOperands, RecomposePixelUnshuffle, ReduceDynamicRange, - ReplaceInfBuffer, + ReplaceArangeArgs, + ReplaceInfValues, ) from executorch.backends.transforms.decompose_sdpa import ( DecomposeScaledDotProductAttention, @@ -222,11 +224,16 @@ def set_per_channel_linear_quant(self, enable: bool) -> None: def transform_for_annotation(self, model: GraphModule) -> GraphModule: model = ReduceDynamicRange()(model).graph_module model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module + model = ReplaceArangeArgs()(model).graph_module model = DecomposeScaledDotProductAttention()(model).graph_module model = DecomposeSilu()(model).graph_module model = DecomposeEinsum()(model).graph_module + model = DecomposeExpM1()(model).graph_module model = DecomposeLinalgVectorNorm(aten_dialect_capture=True)(model).graph_module - model = ReplaceInfBuffer()(model).graph_module + model = ReplaceInfValues()(model).graph_module + from executorch.backends.qualcomm.utils.utils import draw_graph + + draw_graph("checking", ".", model) model = LiftConstantScalarOperands()(model).graph_module return model diff --git 
a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index bdb5541353b..01f00fdcff8 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -6,6 +6,17 @@ import torch +# module with related operator only +class And(torch.nn.Module): + def __init__(self, pos, neg): + super().__init__() + self.pos = pos + self.neg = neg + + def forward(self, x, y): + bitwise_and = torch.bitwise_and(x, y).bool() + return torch.where(bitwise_and, self.pos, self.neg) + # module with related operator only class Abs(torch.nn.Module): @@ -455,6 +466,17 @@ def forward(self, x): return self.conv(x) +class ConvTranspose1dSingle(torch.nn.Module): + def __init__(self, bias=True): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose1d( + in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, bias=bias + ) + + def forward(self, x): + return self.conv_transpose(x) + + class ConvTranspose2dSingle(torch.nn.Module): def __init__(self, bias=True): super().__init__() @@ -594,6 +616,15 @@ def forward(self, i, j): return torch.relu(torch.einsum("i,j->ij", i, j)) +class Elu(torch.nn.Module): + def __init__(self): + super().__init__() + self.elu = torch.nn.ELU(alpha=0.5) + + def forward(self, i): + return self.elu(i) + + class Embedding(torch.nn.Module): def __init__(self): super().__init__() @@ -638,6 +669,14 @@ def forward(self, x): return y.expand_as(x) +class ExpM1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.special.expm1(x) + + class Full(torch.nn.Module): def __init__(self, fill, shape): super().__init__() @@ -1447,3 +1486,40 @@ def __init__(self, pos, neg): def forward(self, x): return torch.where(x >= torch.zeros(x.shape), self.pos, self.neg) + + +class WhereConstantOther(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.where(x >= 0, torch.ones(x.shape), 0) + + +class WhereConstantAll(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.where(x >= 0, 1, 0) * 2.0 + + +class WhereConstantInf(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.softmax( + torch.where(x >= 0, 0.1, float("-inf")), dim=-1 + ) + + +# Mimi Decoder has 0D tensor which QNN cannot handle. 
+class ZeroDimTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + input1 = torch.zeros(1) + selected_element = torch.select(input1, 0, 0) + return torch.add(x, selected_element) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 986243d7a9c..80e3becec79 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -169,7 +169,7 @@ def test_qnn_backend_clamp(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_conv1d(self): + def test_qnn_backend_conv1ds(self): modules = [Conv1dSequential(), Conv1dSequential(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3]),) for i, module in enumerate(modules): @@ -193,6 +193,16 @@ def test_qnn_backend_conv2d_channel_last(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose1d(self): + modules = [ + ConvTranspose1dSingle(), # noqa: F405 + ConvTranspose1dSingle(bias=False), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose2d(self): modules = [ ConvTranspose2dSingle(), # noqa: F405 @@ -253,6 +263,11 @@ def test_qnn_backend_element_wise_add(self): self.lower_module_and_test_output(module, sample_input) index += 1 + def test_qnn_backend_element_wise_and(self): + module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 + sample_input = (torch.tensor([1, 0, 1, 0], dtype=torch.bool), torch.tensor([1, 1, 0, 0], dtype=torch.bool),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_element_wise_ceil(self): module = Ceil() # noqa: F405 sample_input = (torch.randn([2, 5, 1, 3]),) @@ -340,6 +355,12 @@ def test_qnn_backend_element_wise_sub(self): self.lower_module_and_test_output(module, sample_input) index += 1 + @unittest.expectedFailure + def test_qnn_backend_elu(self): + module = Elu() # noqa: F405 + sample_input = (torch.randn(2, 5, 1, 3),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_embedding(self): module = Embedding() # noqa: F405 sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),) @@ -369,6 +390,11 @@ def test_qnn_backend_expand(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_expm1(self): + sample_input = (torch.randn(3, 4, 5),) + module = ExpM1() # noqa: F405 + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_full(self): shape = (1, 2, 3, 4) module = Full(0.5, shape) # noqa: F405 @@ -763,10 +789,16 @@ def test_qnn_backend_where(self): modules = [ Where(), # noqa: F405 WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405 + WhereConstantOther(), # noqa: F405 + WhereConstantAll(), # noqa: F405 + WhereConstantInf(), # noqa: F405 ] sample_inputs = [ (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)), (torch.randn(3, 2),), + (torch.randn(3, 2),), + (torch.randn(3, 2),), + (torch.randn(30, 20),), ] for i, module in enumerate(modules): self.lower_module_and_test_output(module, sample_inputs[i]) @@ -894,6 +926,11 @@ def test_qnn_backend_view_permute_matmul(self): sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256])) self.lower_module_and_test_output(module, sample_input) + def 
test_qnn_backend_zero_dim_tensor(self): + module = ZeroDimTensor() # noqa: F405 + sample_input = (torch.randn(1, 256, 125),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_example_models(self): # TODO Fix MobileBertModelExample and TorchVisionViTModel instances = [ @@ -1121,6 +1158,17 @@ def test_qnn_backend_conv2d_channel_last(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose1d(self): + modules = [ + ConvTranspose1dSingle(), # noqa: F405 + ConvTranspose1dSingle(bias=False), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose2d(self): modules = [ ConvTranspose2dSingle(), # noqa: F405 @@ -1186,6 +1234,12 @@ def test_qnn_backend_element_wise_add(self): self.lower_module_and_test_output(module, sample_input) index += 1 + def test_qnn_backend_element_wise_and(self): + module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 + sample_input = (torch.tensor([1, 0, 1, 0], dtype=torch.bool), torch.tensor([1, 1, 0, 0], dtype=torch.bool),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_element_wise_ceil(self): module = Ceil() # noqa: F405 sample_input = (torch.randn([2, 5, 1, 3]),) @@ -1278,6 +1332,12 @@ def test_qnn_backend_element_wise_sub(self): self.lower_module_and_test_output(module, sample_input) index += 1 + def test_qnn_backend_elu(self): + module = Elu() # noqa: F405 + sample_input = (torch.randn(2, 5, 1, 3),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_embedding(self): module = Embedding() # noqa: F405 sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),) @@ -1310,6 +1370,12 @@ def test_qnn_backend_expand(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_expm1(self): + sample_input = (torch.randn(3, 4, 5),) + module = ExpM1() # noqa: F405 + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_full(self): shape = (1, 2, 3, 4) module = Full(0.5, shape) # noqa: F405 @@ -1771,10 +1837,16 @@ def test_qnn_backend_where(self): modules = [ Where(), # noqa: F405 WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405 + # WhereConstantOther(), # noqa: F405 + WhereConstantAll(), # noqa: F405 + # WhereConstantInf(), # noqa: F405 ] sample_inputs = [ (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)), (torch.randn(3, 2),), + # (torch.randn(3, 2),), + (torch.randn(3, 2),), + # (torch.randn(30, 20),), ] for i, module in enumerate(modules): module = self.get_qdq_module(module, sample_inputs[i]) @@ -1921,6 +1993,12 @@ def test_qnn_backend_view_permute_matmul(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_zero_dim_tensor(self): + module = ZeroDimTensor() # noqa: F405 + sample_input = (torch.randn(1, 256, 125),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_example_models(self): instances = [ { diff 
--git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index edd674d6381..9bbe7e067cd 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -19,11 +19,14 @@ from executorch.backends.qualcomm._passes import ( AnnotateDecomposed, AnnotateQuantAttrs, + AnnotateStack, ConstantI64toI32, ConvertBmmToMatmul, + ConvertConv1dToConv2d, ConvertInterpolateWithUpsample2D, ConvertToLinear, DecomposeAny, + DecomposeExpM1, DecomposeLinalgVectorNorm, ExpandBroadcastTensorShape, FoldQDQ, @@ -32,6 +35,7 @@ RecomposePixelUnshuffle, RecomposePReLU, RecomposeRmsNorm, + RemoveEmptyTensor, RemoveRedundancy, ReplaceIndexPutInput, ) @@ -327,6 +331,7 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: # The below super ops are supported by QNN skip_decompositions = [ torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.elu.default, torch.ops.aten.instance_norm.default, torch.ops.aten.pixel_shuffle.default, torch.ops.aten.pixel_unshuffle.default, @@ -354,8 +359,10 @@ def get_capture_program_passes(): default_passes_and_setting = [ (AnnotateDecomposed, True), (AnnotateQuantAttrs, True), + (AnnotateStack, True), (ConstantI64toI32, True), (ConvertBmmToMatmul, True), + (ConvertConv1dToConv2d, True), (ConvertInterpolateWithUpsample2D, True), (ConvertToLinear, True), (DecomposeAny, True), @@ -366,6 +373,7 @@ def get_capture_program_passes(): (RecomposePReLU, True), (RecomposePixelUnshuffle, True), (RecomposeRmsNorm, True), + (RemoveEmptyTensor, True), (RemoveRedundancy, True), (ReplaceIndexPutInput, True), (TensorI64toI32, True), @@ -438,9 +446,11 @@ def _transform( def _preprocess_module(module: torch.nn.Module, inputs: Tuple[torch.Tensor]): if isinstance(module, torch.fx.graph_module.GraphModule): return module - module = torch.export.export(module, inputs, strict=True).module() + module = torch.export.export(module, inputs, strict=False).module() + module = DecomposeScaledDotProductAttention()(module).graph_module module = DecomposeLinalgVectorNorm(True)(module).graph_module + module = DecomposeExpM1()(module).graph_module module = LiftConstantScalarOperands()(module).graph_module return module @@ -452,7 +462,9 @@ def capture_program( dynamic_shapes: Dict = None, ) -> exir.ExirExportedProgram: module = _preprocess_module(module, inputs) - ep = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes, strict=True) + ep = torch.export.export( + module, inputs, dynamic_shapes=dynamic_shapes, strict=False + ) decomposed_ep = ep.run_decompositions(get_decomp_table()) core_ep = ExirExportedProgram(decomposed_ep, False) core_ep.transform(TensorI64toI32(edge_program=core_ep)) diff --git a/examples/qualcomm/oss_scripts/moshi/install_requirments.sh b/examples/qualcomm/oss_scripts/moshi/install_requirments.sh new file mode 100755 index 00000000000..9a7cc044ae6 --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/install_requirments.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
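
`_preprocess_module` above now exports non-strictly, lowers the result back to an `nn.Module`, runs source-level passes such as `DecomposeExpM1`, and `capture_program` then exports the module again before decomposition. A rough sketch of that export → pass → re-export flow (the module and the commented pass call are toy stand-ins):

```python
import torch
from torch.export import export

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.special.expm1(x)

example = (torch.randn(2, 3),)
# non-strict export, then back to an nn.Module so graph passes can rewrite it
gm = export(Toy(), example, strict=False).module()
# ... passes such as DecomposeExpM1()(gm) would run here ...
ep = export(gm, example, strict=False)
print(ep.graph_module.code)
```
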
+ +set -x + +pip install -U moshi +pip install bitsandbytes +# Run llama2/install requirements for torchao deps +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +bash "$SCRIPT_DIR"/../../../models/llama/install_requirements.sh \ No newline at end of file diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py new file mode 100644 index 00000000000..f3bc6f26eec --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -0,0 +1,241 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# import argparse +import json +import os +import random +import time +from multiprocessing.connection import Client +import requests +import io +import torchaudio + +import numpy as np + +import sphn +import torch + +import torch.nn as nn +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +# from executorch.examples.models.llama.llama_transformer import Transformer + +# from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) + +from huggingface_hub import hf_hub_download +from moshi.models import loaders + +from torch.profiler import profile, ProfilerActivity + + +def seed_all(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # for multi-GPU setups + random.seed(seed) + np.random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def read_mp3_from_url(url): + response = requests.get(url) + response.raise_for_status() # Ensure request is successful + + # Convert to a file-like object + audio_stream = io.BytesIO(response.content) + + # Load audio using torchaudio + waveform, sample_rate = torchaudio.load(audio_stream, format="mp3") + + return waveform.numpy(), sample_rate + + +def mimi_encode(): + None +def mimi_decode(): + None +def mimi_test(mimi, args, max_duration_sec=10.0): + pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) + sample_rate = mimi.sample_rate + url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3" + sample_pcm, sample_sr = read_mp3_from_url(url) + pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) + sample_rate = mimi.sample_rate + sample_pcm = torch.tensor(sample_pcm, device='cpu') + max_duration_len = int(sample_rate * max_duration_sec) + if sample_pcm.shape[-1] > max_duration_len: + sample_pcm = sample_pcm[..., :max_duration_len] + sample_pcm = sample_pcm[None].to(device='cpu') + # sample_pcm = torch.ones(1, 1, 240000) + + print("streaming encoding...") + start_time = time.time() + all_codes = [] + + def run_loop(): + for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size): + end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size) + chunk = sample_pcm[..., start_idx:end_idx] + codes = mimi.encode(chunk) + if codes.shape[-1]: + print(start_idx, codes.shape, end="\r") + all_codes.append(codes) + if args.profile: + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + run_loop() + prof.export_chrome_trace("trace.json") + else: + run_loop() + all_codes_th = torch.cat(all_codes, dim=-1) + print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s") + 
print("streaming decoding...") + all_pcms = [] + with mimi.streaming(1): + for i in range(all_codes_th.shape[-1]): + codes = all_codes_th[..., i : i + 1] + pcm = mimi.decode(codes) + print(i, pcm.shape, end="\r") + all_pcms.append(pcm) + all_pcms = torch.cat(all_pcms, dim=-1) + pcm_ref = mimi.decode(all_codes_th) # same as mimi_decode(input[0]) + print("pcm", all_pcms.shape, all_pcms.dtype) + + assert torch.allclose(pcm_ref, all_pcms, atol=1e-5) + + class MimiDecode(nn.Module): + def __init__(self, mimi: nn.Module): + super().__init__() + self.mimi_model = mimi + + def forward(self, x): + return self.mimi_model.decode(x) + + mimi_decode = MimiDecode(mimi) + + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + pte_filename = "mimi_qnn" + input = (all_codes_th.to(torch.int32),) + build_executorch_binary( + mimi_decode.eval(), + input, + args.model, + f"{args.artifact}/{pte_filename}", + [input], + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + return + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=[input], input_list="input_0_0.raw") + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + # top-k analysis + predictions = [] + predictions.append( + np.fromfile( + os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.float32 + ) + ) + htp_res = torch.from_numpy(predictions[0]).view(1, 1, 240000) + cosine_sim = torch.nn.functional.cosine_similarity( + pcm_ref.flatten(), htp_res.flatten(), dim=0 + ).item() + print("Cos similarity: ", cosine_sim) + sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate) + sphn.write_wav("ref.wav", pcm_ref[0, 0].cpu().numpy(), sample_rate) + sphn.write_wav("htp.wav", htp_res[0,0].cpu().numpy(), sample_rate) + # With QNN 2.28.2 + # 0.9650231003761292 + # 8a8w cos similarity: 0.9635128378868103, 1 inference: 73.5ms, file size: ~59mb + # 16a16w cos similarity: failed at runner: Error from rpc transport, file size: ~104mb + # 16a4w cos similarity: failed at runner: Error from rpc transport, file size: ~53mb + # fp cos similarity: failed at runner: Error from rpc transport, file size: ~101mb (QNN 2.31.0 for this) + + class MimiEncode(nn.Module): + def __init__(self, mimi: nn.Module): + super().__init__() + self.mimi_model = mimi + + def forward(self, x): + return self.mimi_model.encode(x) + + mimi_encode = MimiEncode(mimi) + chunk = sample_pcm[..., 0:pcm_chunk_size] + out = mimi_encode(chunk) + exported_encode = torch.export.export(mimi_encode, (chunk,), strict=False).module() + + +def main(args): + seed_all(42424242) + + print("loading mimi") + if args.mimi_weight is None: + args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME) + mimi = loaders.get_mimi(args.mimi_weight, "cpu") + print("mimi loaded") + # emb = torch.load('emb.pt') + + with torch.no_grad(): + mimi_test(mimi, args) + +if __name__ == "__main__": + + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", 
+ help="path for storing generated artifacts by this example. Default ./ssd300_vgg16", + default="./mimi", + type=str, + ) + + parser.add_argument("--mimi-weight", type=str) + parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO) + # parser.add_argument( + # "--device", type=str, default="cpu" if torch.cuda.device_count() else "cpu" + # ) + parser.add_argument("--profile", action="store_true") + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/moshi/moshi_example.py b/examples/qualcomm/oss_scripts/moshi/moshi_example.py new file mode 100644 index 00000000000..edfda9f9aab --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/moshi_example.py @@ -0,0 +1,189 @@ +# # Copyright (c) Qualcomm Innovation Center, Inc. +# # All rights reserved +# # +# # This source code is licensed under the BSD-style license found in the +# # LICENSE file in the root directory of this source tree. + +# import json +# import os +# import sys +# from multiprocessing.connection import Client + +# import numpy as np + +# import torch +# from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +# from executorch.examples.qualcomm.utils import ( +# build_executorch_binary, +# get_imagenet_dataset, +# make_output_dir, +# parse_skip_delegation_node, +# setup_common_args_and_variables, +# SimpleADB, +# topk_accuracy, +# ) + +# from huggingface_hub import hf_hub_download + +# # python examples/qualcomm/oss_scripts/moshi/moshi.py -b build-android -H mlgtw-linux2 -s acfa9311 -m SM8650 + + +# def main(args): +# pte_filename = "moshi_qnn" +# sys.path.insert(0, "../moshi/moshi/") +# from moshi.models import LMGen, loaders +# from moshi.modules.seanet import SEANetEncoder + +# skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + +# # ensure the working directory exist. +# os.makedirs(args.artifact, exist_ok=True) + +# # ---------------------------------method1------------------------ +# # _seanet_kwargs = { +# # "channels": 1, +# # "dimension": 512, +# # "causal": True, +# # "n_filters": 64, +# # "n_residual_layers": 1, +# # "activation": "ELU", +# # "compress": 2, +# # "dilation_base": 2, +# # "disable_norm_outer_blocks": 0, +# # "kernel_size": 7, +# # "residual_kernel_size": 3, +# # "last_kernel_size": 3, +# # # We train using weight_norm but then the weights are pre-processed for inference so +# # # that we can use a normal convolution. +# # "norm": "none", +# # "pad_mode": "constant", +# # "ratios": [8, 6, 5, 4], +# # "true_skip": True, +# # } + +# # mimi = SEANetEncoder(**_seanet_kwargs).eval() +# # ---------------------------------method1------------------------ + +# mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME) +# mimi = loaders.get_mimi(mimi_weight, device="cpu") +# mimi.set_num_codebooks(8) # up to 32 for mimi, but limited to 8 for moshi. 
+ +# wav = (torch.randn(1, 1, 24000 * 10),) # should be [B, C=1, T] + +# build_executorch_binary( +# mimi.eval(), +# wav, +# args.model, +# f"{args.artifact}/{pte_filename}", +# [wav], +# skip_node_id_set=skip_node_id_set, +# skip_node_op_set=skip_node_op_set, +# quant_dtype=QuantDtype.use_8a8w, +# shared_buffer=args.shared_buffer, +# ) +# # import pdb; pdb.set_trace() +# # with torch.no_grad(): +# # codes = mimi.encode(wav) # [B, K = 8, T] +# # decoded = mimi.decode(codes) #################################################### Unused, more like showcase + +# # # Supports streaming too. +# # frame_size = int(mimi.sample_rate / mimi.frame_rate) +# # all_codes = [] +# # with mimi.streaming(batch_size=1): +# # for offset in range(0, wav.shape[-1], frame_size): +# # frame = wav[:, :, offset: offset + frame_size] +# # codes = mimi.encode(frame) +# # assert codes.shape[-1] == 1, codes.shape +# # all_codes.append(codes) + +# # ## WARNING: When streaming, make sure to always feed a total amount of audio that is a multiple +# # # of the frame size (1920), otherwise the last frame will not be complete, and thus +# # # will not be encoded. For simplicity, we recommend feeding in audio always in multiple +# # # of the frame size, so that you always know how many time steps you get back in `codes`. + +# # # Now if you have a GPU around. +# # # mimi.cuda() +# # moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME) +# # moshi = loaders.get_moshi_lm(moshi_weight, device='cpu') +# # lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7) # this handles sampling params etc. +# # out_wav_chunks = [] +# # # Now we will stream over both Moshi I/O, and decode on the fly with Mimi. +# # with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1): +# # for idx, code in enumerate(all_codes): +# # tokens_out = lm_gen.step(code) +# # # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token. 
+# # if tokens_out is not None: +# # wav_chunk = mimi.decode(tokens_out[:, 1:]) +# # out_wav_chunks.append(wav_chunk) +# # print(idx, end='\r') +# # out_wav = torch.cat(out_wav_chunks, dim=-1) +# # import pdb; pdb.set_trace() + +# from datasets import Audio, load_dataset +# from transformers import AutoFeatureExtractor, MimiModel + +# librispeech_dummy = load_dataset( +# "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" +# ) + +# # load model and feature extractor +# model = MimiModel.from_pretrained("kyutai/mimi") +# feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi") + +# # load audio sample +# librispeech_dummy = librispeech_dummy.cast_column( +# "audio", Audio(sampling_rate=feature_extractor.sampling_rate) +# ) +# # audio_sample = lib[-1]["audio"]["array"] +# audio_sample = torch.randn(24000 * 10) +# inputs = feature_extractor( +# raw_audio=audio_sample, +# sampling_rate=feature_extractor.sampling_rate, +# return_tensors="pt", +# ) +# audio_sample = ( +# inputs["input_values"], +# inputs["padding_mask"], +# ) +# build_executorch_binary( +# model.eval(), +# audio_sample, +# args.model, +# f"{args.artifact}/{pte_filename}", +# [audio_sample], +# skip_node_id_set=skip_node_id_set, +# skip_node_op_set=skip_node_op_set, +# quant_dtype=QuantDtype.use_8a8w, +# shared_buffer=args.shared_buffer, +# ) +# import pdb + +# pdb.set_trace() +# encoder_outputs = model.encode( +# inputs["input_values"], inputs["padding_mask"] +# ) # torch.randn(1, 240000), torch.randn(1, 1, 240000) +# audio_values = model.decode(encoder_outputs.audio_codes, inputs["padding_mask"])[0] +# # or the equivalent with a forward pass +# audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values + + +# if __name__ == "__main__": +# parser = setup_common_args_and_variables() + +# parser.add_argument( +# "-a", +# "--artifact", +# help="path for storing generated artifacts by this example. Default ./ssd300_vgg16", +# default="./moshi", +# type=str, +# ) + +# args = parser.parse_args() +# try: +# main(args) +# except Exception as e: +# if args.ip and args.port != -1: +# with Client((args.ip, args.port)) as conn: +# conn.send(json.dumps({"Error": str(e)})) +# else: +# raise Exception(e) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 5ecce30078e..f1646c62f3f 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -324,7 +324,7 @@ def build_executorch_binary( None: The function writes the output to a specified .pte file. 
""" if quant_dtype is not None: - captured_model = torch.export.export(model, inputs, strict=True).module() + captured_model = torch.export.export(model, inputs, strict=False).module() if qat_training_data: quantizer = custom_quantizer or make_quantizer( quant_dtype=quant_dtype, is_qat=True From b68190944651238f3dea068727ab0155abb78253 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Thu, 13 Mar 2025 17:19:38 +0800 Subject: [PATCH 2/2] Encoder Enablement --- backends/qualcomm/_passes/__init__.py | 2 + .../qualcomm/_passes/annotate_decomposed.py | 2 +- backends/qualcomm/_passes/annotate_stack.py | 14 +- backends/qualcomm/_passes/decompose_cdist.py | 77 +++++++ backends/qualcomm/_passes/layout_transform.py | 3 + .../qualcomm/_passes/replace_inf_values.py | 37 ++++ backends/qualcomm/builders/__init__.py | 2 + backends/qualcomm/builders/op_index.py | 9 +- backends/qualcomm/builders/op_stack.py | 72 +++++++ backends/qualcomm/quantizer/annotators.py | 53 +++-- backends/qualcomm/quantizer/quantizer.py | 5 +- backends/qualcomm/tests/models.py | 13 +- backends/qualcomm/tests/test_qnn_delegate.py | 36 +++- backends/qualcomm/utils/utils.py | 4 + examples/qualcomm/oss_scripts/moshi/mimi.py | 195 +++++++++--------- .../oss_scripts/moshi/moshi_example.py | 189 ----------------- install_requirements.py | 2 +- 17 files changed, 380 insertions(+), 335 deletions(-) create mode 100644 backends/qualcomm/_passes/decompose_cdist.py create mode 100644 backends/qualcomm/_passes/replace_inf_values.py create mode 100644 backends/qualcomm/builders/op_stack.py delete mode 100644 examples/qualcomm/oss_scripts/moshi/moshi_example.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index d3761208fc9..30f2addac6b 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -7,6 +7,7 @@ from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D from .convert_to_linear import ConvertToLinear from .decompose_any import DecomposeAny +from .decompose_cdist import DecomposeCDist from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm @@ -41,6 +42,7 @@ RecomposePReLU, ConvertToLinear, DecomposeAny, + DecomposeCDist, DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, diff --git a/backends/qualcomm/_passes/annotate_decomposed.py b/backends/qualcomm/_passes/annotate_decomposed.py index a8a757ce9bf..e2d23304f3c 100644 --- a/backends/qualcomm/_passes/annotate_decomposed.py +++ b/backends/qualcomm/_passes/annotate_decomposed.py @@ -32,7 +32,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule): n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() def _annotate_stack(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions(graph_module.graph, [torch.stack]) + partitions = get_source_partitions(graph_module.graph, [torch.stack]) # TODO: Add "stack" later for _, src_partitions in partitions.items(): for src_partition in src_partitions: output = src_partition.output_nodes[0] diff --git a/backends/qualcomm/_passes/annotate_stack.py b/backends/qualcomm/_passes/annotate_stack.py index e2565ba9356..507240cfd39 100644 --- a/backends/qualcomm/_passes/annotate_stack.py +++ b/backends/qualcomm/_passes/annotate_stack.py @@ -8,7 +8,7 @@ from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.pass_base import ExportPass, PassResult - +#TODO: Remove this and merge it with 
annotate_decomposed. class AnnotateStack(ExportPass): """ During decomposition stage, some unsqueeze op will appear. @@ -25,12 +25,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: node.meta.get("torch_fn", ("", ""))[1] == "builtin_function_or_method.stack" ): - if ( - QCOM_QUANT_ATTRS not in node.meta - and QCOM_QUANT_ATTRS in node.args[0].meta - ): - node.meta[QCOM_QUANT_ATTRS] = node.args[0].meta[QCOM_QUANT_ATTRS] - + input1 = node.args[0] if isinstance(node.args[0], torch.fx.node.Node) else node.args[0][0] + if QCOM_QUANT_ATTRS not in node.meta and QCOM_QUANT_ATTRS in input1.meta and node.meta["val"].is_floating_point(): + node.meta[QCOM_QUANT_ATTRS] = input1.meta[QCOM_QUANT_ATTRS] + graph.eliminate_dead_code() graph_module.recompile() - return PassResult(graph_module, True) + return PassResult(graph_module, True) \ No newline at end of file diff --git a/backends/qualcomm/_passes/decompose_cdist.py b/backends/qualcomm/_passes/decompose_cdist.py new file mode 100644 index 00000000000..1c28be832a8 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_cdist.py @@ -0,0 +1,77 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + + +class CDist(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + # Step 1: Compute differences + diff = x.unsqueeze(2) - y.unsqueeze(1) + + # Step 2: Square differences + sq_diff = diff**2 + + # Step 3: Sum of squares + sum_sq_diff = sq_diff.sum(dim=-1) + + # Step 4: Square root + distances = torch.sqrt(sum_sq_diff) + + return distances + + +class DecomposeCDist(ExportPass): + """ + Decompose for math equivalent op. 
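+
+    For p = 2 this is, roughly,
+        dist[b, i, j] = sqrt(sum_k (x[b, i, k] - y[b, j, k]) ** 2),
+    i.e. the unsqueeze/sub/square/sum/sqrt sequence implemented by the CDist
+    module above, leaving only simple elementwise ops in the graph.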
+ """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + model = CDist() + if torch.ops.aten.cdist.default == node.target: + decomposed_module = torch.export.export( + model, + (node.args[0].meta["val"], node.args[1].meta["val"]), + strict=True, + ).module() + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": node.args[0], "y": node.args[1]} + + for decomposed_node in decomposed_module.graph.nodes: + # no need to copy existent 'output' + if decomposed_node.op == "output": + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[decomposed_node.args[0][0]], + ) + # no need to copy existent placeholders + elif decomposed_node.op == "placeholder": + # replace node map from string to graph node + remap[decomposed_node] = remap.pop(decomposed_node.name) + else: + remap[decomposed_node] = graph.node_copy( + decomposed_node, + arg_transform=lambda x, remap=remap: remap[x], + ) + + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index f3b8fd59065..d0ca0d77a27 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -48,6 +48,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.bmm.default, + exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.ceil.default, exir_ops.edge.aten.clamp.default, @@ -88,11 +89,13 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.stack.default, exir_ops.edge.aten.topk.default, exir_ops.edge.aten._to_copy.default, exir_ops.edge.aten.where.self, *q_ops, *dq_ops, + torch.ops.aten.scalar_tensor.default, _operator.getitem, } diff --git a/backends/qualcomm/_passes/replace_inf_values.py b/backends/qualcomm/_passes/replace_inf_values.py new file mode 100644 index 00000000000..b63d54c357f --- /dev/null +++ b/backends/qualcomm/_passes/replace_inf_values.py @@ -0,0 +1,37 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from executorch.exir.pass_base import ExportPass, PassResult + + +class ReplaceInfValues(ExportPass): + """ + Due to limitation in Qnn, we need to change inf or -inf to arbitrary value in quantization. + This could be a buffer or a node's argument. 
+ """ + + def __init__(self): + super(ReplaceInfValues, self).__init__() + + def call(self, graph_module: torch.fx.GraphModule): + for buf_name, tensor in graph_module.named_buffers(): + if tensor.is_floating_point(): + # 255 here is mainly for attention_mask in Llama for reasonable quant scale + tensor[tensor == float("inf")] = 255 + tensor[tensor == float("-inf")] = -255 + setattr(graph_module, buf_name, tensor) + + for node in graph_module.graph.nodes: + arg_list = list(node.args) + for index, arg in enumerate(arg_list): + if arg == float("-inf"): + arg_list[index] = torch.finfo(torch.float32).min + elif arg == float("inf"): + arg_list[index] = torch.finfo(torch.float32).max + node.args = tuple(arg_list) + + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 5e9573932bf..d135c380a60 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -75,6 +75,7 @@ op_split_with_sizes, op_sqrt, op_squeeze, + op_stack, op_sub, op_sum_int_list, op_tanh, @@ -158,6 +159,7 @@ op_split_with_sizes, op_squeeze, op_sqrt, + op_stack, op_sub, op_sum_int_list, op_tanh, diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py index e78284a5e32..483f20107c5 100644 --- a/backends/qualcomm/builders/op_index.py +++ b/backends/qualcomm/builders/op_index.py @@ -38,11 +38,8 @@ def define_node( nodes_to_wrappers, ) - if len(node.args[1]) > 1: - # TODO consider to implement it in a recursive way. - raise NotImplementedError("Not support tuple of tensor.") - - indices_node = node.args[1][0] + axis = len(node.args[1]) - 1 + indices_node = node.args[1][axis] indices_tensor = self.get_tensor(indices_node, node).to(torch.int32) assert indices_tensor.size(0) != 0, "Not support empty indices list" @@ -78,7 +75,7 @@ def define_node( gather_op.AddScalarParam( OpGather.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, - {QCOM_DATA: np.int32(0)}, + {QCOM_DATA: np.int32(axis)}, ) return gather_op diff --git a/backends/qualcomm/builders/op_stack.py b/backends/qualcomm/builders/op_stack.py new file mode 100644 index 00000000000..de0d484e825 --- /dev/null +++ b/backends/qualcomm/builders/op_stack.py @@ -0,0 +1,72 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpPack, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Stack(NodeVisitor):
+    target = ["aten.stack.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node_list = node.args[0]
+        stack_input_tensors = []
+        for input_node in input_node_list:
+            input_tensor = self.get_tensor(input_node, node)
+            stack_inp_tensor_wrapper = self.define_tensor(
+                input_node,
+                node,
+                input_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                nodes_to_wrappers,
+            )
+            stack_input_tensors.append(stack_inp_tensor_wrapper)
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        stack_output_tensors = [output_tensor_wrapper]
+
+        dim = 0 if len(node.args) == 1 else cast(int, node.args[1])
+        if dim < 0:
+            dim = dim % len(input_tensor.shape)
+        if QCOM_AXIS_ORDER in node.meta:
+            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
+        stack_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpPack.op_name,
+        )
+        stack_op.AddInputTensors(stack_input_tensors)
+        stack_op.AddOutputTensors(stack_output_tensors)
+
+        stack_op.AddScalarParam(
+            OpPack.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        return stack_op
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index c87f4686a80..8da8fdbc282 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -97,6 +97,7 @@ def annotate_in_out_obs_sharing_op(
         QUANT_ANNOTATION_KEY not in input_act.meta
         or not input_act.meta[QUANT_ANNOTATION_KEY]._annotated
         or input_act.meta[QUANT_ANNOTATION_KEY].output_qspec is None
+        or not _is_float_tensor(input_act)
     ):
         return
 
@@ -132,9 +133,10 @@ def annotate_single_in_single_out(
         return
 
     input_qspec_map = {}
-    input_act = node.args[0]
-    assert isinstance(input_act, Node)
-    input_qspec_map[input_act] = quantization_config.input_activation
+    if _is_float_tensor(node.args[0]):
+        input_act = node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = quantization_config.input_activation
 
     if _is_float_tensor(node):
         node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
@@ -176,6 +178,9 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
         _annotated=True,
     )
 
+@register_annotator([torch.ops.aten.__and__.Tensor])
+def annotate_and(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
 
 @register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
 def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
@@ -478,6 +483,8 @@ def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator([torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default])
 def annotate_view(node: Node, quantization_config:
QuantizationConfig) -> None: + # if node.args[0].target == torch.ops.aten.argmin.default: + # import pdb; pdb.set_trace() annotate_in_out_obs_sharing_op(node, quantization_config) if not _is_annotated([node]): annotate_single_in_single_out(node, quantization_config) @@ -835,15 +842,24 @@ def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} for input_act in node.args[0]: assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation + first_input_node = node.args[0][0] + if _is_float_tensor(first_input_node): + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) - node_tensor = node.meta.get("val") - if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64: - continue + for input_node in node.args[0][1:]: + if input_node not in input_qspec_map: + assert isinstance(input_node, Node) + if _is_float_tensor(input_node): + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, + output_qspec=( + share_qparams_with_input_act0_qspec if _is_float_tensor(node) else None + ), _annotated=True, ) @@ -909,6 +925,11 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: node.meta["source_fn_stack"] = [(node, torch.bmm)] +@register_annotator([torch.ops.aten.cdist.default]) +def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + @register_annotator( [ torch.ops.aten.conv2d.default, @@ -1082,19 +1103,23 @@ def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} assert isinstance(first_input_node, Node) assert isinstance(node, Node) - input_qspec_map[first_input_node] = quantization_config.input_activation - share_qparams_with_input_act0_qspec = SharedQuantizationSpec( - (first_input_node, node) - ) + if _is_float_tensor(first_input_node): + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) for input_node in input_nodes[1:]: if input_node not in input_qspec_map: assert isinstance(input_node, Node) - input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + if _is_float_tensor(input_node): + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=share_qparams_with_input_act0_qspec, + output_qspec=( + share_qparams_with_input_act0_qspec if _is_float_tensor(node) else None + ), _annotated=True, ) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index e983613c518..90e4dc13a58 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -9,6 +9,7 @@ import torch from executorch.backends.qualcomm._passes import ( + DecomposeCDist, DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, @@ -230,10 +231,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule: model = DecomposeEinsum()(model).graph_module model = DecomposeExpM1()(model).graph_module model = DecomposeLinalgVectorNorm(aten_dialect_capture=True)(model).graph_module + model = 
DecomposeCDist()(model).graph_module model = ReplaceInfValues()(model).graph_module - from executorch.backends.qualcomm.utils.utils import draw_graph - - draw_graph("checking", ".", model) model = LiftConstantScalarOperands()(model).graph_module return model diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 01f00fdcff8..e9b640f0256 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -6,6 +6,7 @@ import torch + # module with related operator only class And(torch.nn.Module): def __init__(self, pos, neg): @@ -178,6 +179,14 @@ def forward(self, x, y): return torch.cat((y, y, x, x), axis=2) +class CDist(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.cdist(x, y, p=2) + + class Ceil(torch.nn.Module): def __init__(self): super().__init__() @@ -1376,8 +1385,8 @@ class Stack(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x, y): - return torch.stack((x, y)) + def forward(self, x, y, z): + return torch.stack((x, y, z)) class Sub(torch.nn.Module): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 80e3becec79..b3138bc5a0a 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -157,6 +157,14 @@ def test_qnn_backend_cat(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cdist(self): + module = CDist() # noqa: F405 + sample_input = ( + torch.randn(1, 125, 256), + torch.randn(1, 2048, 256), + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_chunk_single(self): module = Chunk() # noqa: F405 sample_input = (torch.randn(1, 1, 4, 3),) @@ -169,7 +177,7 @@ def test_qnn_backend_clamp(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_conv1ds(self): + def test_qnn_backend_conv1d(self): modules = [Conv1dSequential(), Conv1dSequential(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3]),) for i, module in enumerate(modules): @@ -265,7 +273,10 @@ def test_qnn_backend_element_wise_add(self): def test_qnn_backend_element_wise_and(self): module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 - sample_input = (torch.tensor([1, 0, 1, 0], dtype=torch.bool), torch.tensor([1, 1, 0, 0], dtype=torch.bool),) + sample_input = ( + torch.tensor([1, 0, 1, 0], dtype=torch.bool), + torch.tensor([1, 1, 0, 0], dtype=torch.bool), + ) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_element_wise_ceil(self): @@ -747,7 +758,7 @@ def test_qnn_backend_slice_copy(self): def test_qnn_backend_stack(self): module = Stack() # noqa: F405 - sample_input = (torch.randn([1, 2, 3, 4]), torch.randn([1, 2, 3, 4])) + sample_input = (torch.randn([1, 2, 3, 4]), torch.randn([1, 2, 3, 4]), torch.randn([1, 2, 3, 4]),) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_softmax(self): @@ -1117,6 +1128,15 @@ def test_qnn_backend_cat(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cdist(self): + module = CDist() # noqa: F405 + sample_input = ( + torch.randn(1, 125, 256), + torch.randn(1, 2048, 256), + ) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_chunk_single(self): module = Chunk() # 
noqa: F405 sample_input = (torch.randn(1, 1, 4, 3),) @@ -1236,7 +1256,10 @@ def test_qnn_backend_element_wise_add(self): def test_qnn_backend_element_wise_and(self): module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 - sample_input = (torch.tensor([1, 0, 1, 0], dtype=torch.bool), torch.tensor([1, 1, 0, 0], dtype=torch.bool),) + sample_input = ( + torch.tensor([1, 0, 1, 0], dtype=torch.bool), + torch.tensor([1, 1, 0, 0], dtype=torch.bool), + ) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) @@ -1796,10 +1819,7 @@ def test_qnn_backend_squeeze(self): def test_qnn_backend_stack(self): module = Stack() # noqa: F405 - sample_input = ( - torch.randn([1, 2, 3, 4]), - torch.randn([1, 2, 3, 4]), - ) + sample_input = (torch.randn([1, 2, 3, 4]),torch.randn([1, 2, 3, 4]), torch.randn([1, 2, 3, 4]),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 9bbe7e067cd..5e4c96f520c 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -26,6 +26,7 @@ ConvertInterpolateWithUpsample2D, ConvertToLinear, DecomposeAny, + DecomposeCDist, DecomposeExpM1, DecomposeLinalgVectorNorm, ExpandBroadcastTensorShape, @@ -338,6 +339,7 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten._safe_softmax.default, + # torch.ops.aten.stack.default, # We should be able to enable this, but QNN does not support int IO. ] remove_decompositions(source_decompositions, skip_decompositions) @@ -448,6 +450,7 @@ def _preprocess_module(module: torch.nn.Module, inputs: Tuple[torch.Tensor]): return module module = torch.export.export(module, inputs, strict=False).module() + module = DecomposeCDist()(module).graph_module module = DecomposeScaledDotProductAttention()(module).graph_module module = DecomposeLinalgVectorNorm(True)(module).graph_module module = DecomposeExpM1()(module).graph_module @@ -465,6 +468,7 @@ def capture_program( ep = torch.export.export( module, inputs, dynamic_shapes=dynamic_shapes, strict=False ) + #TODO: Handle stack op. If we want to run annotate_decomposed pass for stack op, we need to make stack op decompose, which means we need to find a method to remove it from skip_decomp table decomposed_ep = ep.run_decompositions(get_decomp_table()) core_ep = ExirExportedProgram(decomposed_ep, False) core_ep.transform(TensorI64toI32(edge_program=core_ep)) diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py index f3bc6f26eec..7d2b9cd6b67 100644 --- a/examples/qualcomm/oss_scripts/moshi/mimi.py +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -3,27 +3,22 @@ # LICENSE file in the root directory of this source tree. 
# import argparse +import io import json import os import random -import time from multiprocessing.connection import Client -import requests -import io -import torchaudio import numpy as np +import requests import sphn import torch import torch.nn as nn +import torchaudio from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -# from executorch.examples.models.llama.llama_transformer import Transformer - -# from executorch.examples.models.llama.model_args import ModelArgs - from executorch.examples.qualcomm.utils import ( build_executorch_binary, make_output_dir, @@ -35,9 +30,6 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.profiler import profile, ProfilerActivity - - def seed_all(seed): torch.manual_seed(seed) if torch.cuda.is_available(): @@ -48,6 +40,7 @@ def seed_all(seed): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False + def read_mp3_from_url(url): response = requests.get(url) response.raise_for_status() # Ensure request is successful @@ -61,58 +54,58 @@ def read_mp3_from_url(url): return waveform.numpy(), sample_rate -def mimi_encode(): - None -def mimi_decode(): - None -def mimi_test(mimi, args, max_duration_sec=10.0): - pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) - sample_rate = mimi.sample_rate - url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3" - sample_pcm, sample_sr = read_mp3_from_url(url) - pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) - sample_rate = mimi.sample_rate - sample_pcm = torch.tensor(sample_pcm, device='cpu') - max_duration_len = int(sample_rate * max_duration_sec) - if sample_pcm.shape[-1] > max_duration_len: - sample_pcm = sample_pcm[..., :max_duration_len] - sample_pcm = sample_pcm[None].to(device='cpu') - # sample_pcm = torch.ones(1, 1, 240000) +def mimi_encode(mimi, sample_pcm, skip_node_id_set, skip_node_op_set) -> torch.Tensor: + class MimiEncode(nn.Module): + def __init__(self, mimi: nn.Module): + super().__init__() + self.mimi_model = mimi - print("streaming encoding...") - start_time = time.time() - all_codes = [] - - def run_loop(): - for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size): - end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size) - chunk = sample_pcm[..., start_idx:end_idx] - codes = mimi.encode(chunk) - if codes.shape[-1]: - print(start_idx, codes.shape, end="\r") - all_codes.append(codes) - if args.profile: - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - run_loop() - prof.export_chrome_trace("trace.json") - else: - run_loop() - all_codes_th = torch.cat(all_codes, dim=-1) - print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s") - print("streaming decoding...") - all_pcms = [] - with mimi.streaming(1): - for i in range(all_codes_th.shape[-1]): - codes = all_codes_th[..., i : i + 1] - pcm = mimi.decode(codes) - print(i, pcm.shape, end="\r") - all_pcms.append(pcm) - all_pcms = torch.cat(all_pcms, dim=-1) - pcm_ref = mimi.decode(all_codes_th) # same as mimi_decode(input[0]) - print("pcm", all_pcms.shape, all_pcms.dtype) - - assert torch.allclose(pcm_ref, all_pcms, atol=1e-5) + def forward(self, x): + return self.mimi_model.encode(x) + + mimi_encode = MimiEncode(mimi) + encode_input = (sample_pcm,) + pte_filename = "mimi_encoder_qnn" + build_executorch_binary( + mimi_encode.eval(), + encode_input, + args.model, + f"{args.artifact}/{pte_filename}", + [encode_input], + skip_node_id_set=skip_node_id_set, + 
skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + shared_buffer=args.shared_buffer, + ) + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=[encode_input], input_list="input_0_0.raw") + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + predictions = [] + predictions.append( + np.fromfile(os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.int64) + ) + htp_res = torch.from_numpy(predictions[0]).view(1, 8, 125) + return htp_res + + +def mimi_decode(mimi, encode_res, skip_node_id_set, skip_node_op_set) -> torch.Tensor: class MimiDecode(nn.Module): def __init__(self, mimi: nn.Module): super().__init__() @@ -122,27 +115,20 @@ def forward(self, x): return self.mimi_model.decode(x) mimi_decode = MimiDecode(mimi) - - skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) - # ensure the working directory exist. - os.makedirs(args.artifact, exist_ok=True) - pte_filename = "mimi_qnn" - input = (all_codes_th.to(torch.int32),) + decode_input = (encode_res,) + pte_filename = "mimi_decoder_qnn" build_executorch_binary( mimi_decode.eval(), - input, + decode_input, args.model, f"{args.artifact}/{pte_filename}", - [input], + [decode_input], skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, quant_dtype=QuantDtype.use_8a8w, shared_buffer=args.shared_buffer, ) - if args.compile_only: - return - adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), build_path=f"{args.build_folder}", @@ -153,7 +139,7 @@ def forward(self, x): soc_model=args.model, shared_buffer=args.shared_buffer, ) - adb.push(inputs=[input], input_list="input_0_0.raw") + adb.push(inputs=[decode_input], input_list="input_0_0.raw") adb.execute() # collect output data @@ -162,40 +148,44 @@ def forward(self, x): adb.pull(output_path=args.artifact) - # top-k analysis predictions = [] predictions.append( np.fromfile( os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.float32 ) ) - htp_res = torch.from_numpy(predictions[0]).view(1, 1, 240000) - cosine_sim = torch.nn.functional.cosine_similarity( - pcm_ref.flatten(), htp_res.flatten(), dim=0 - ).item() - print("Cos similarity: ", cosine_sim) - sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate) - sphn.write_wav("ref.wav", pcm_ref[0, 0].cpu().numpy(), sample_rate) - sphn.write_wav("htp.wav", htp_res[0,0].cpu().numpy(), sample_rate) - # With QNN 2.28.2 - # 0.9650231003761292 - # 8a8w cos similarity: 0.9635128378868103, 1 inference: 73.5ms, file size: ~59mb - # 16a16w cos similarity: failed at runner: Error from rpc transport, file size: ~104mb - # 16a4w cos similarity: failed at runner: Error from rpc transport, file size: ~53mb - # fp cos similarity: failed at runner: Error from rpc transport, file size: ~101mb (QNN 2.31.0 for this) + htp_decode_res = torch.from_numpy(predictions[0]).view(1, 1, 240000) + return htp_decode_res - class MimiEncode(nn.Module): - def __init__(self, mimi: nn.Module): - super().__init__() - self.mimi_model = mimi - def forward(self, x): - return self.mimi_model.encode(x) +def export_mimi(mimi, args, max_duration_sec=10.0): + sample_rate = mimi.sample_rate + url = 
"https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3" + sample_pcm, sample_sr = read_mp3_from_url(url) + sample_rate = mimi.sample_rate + sample_pcm = torch.tensor(sample_pcm, device="cpu") + max_duration_len = int(sample_rate * max_duration_sec) + if sample_pcm.shape[-1] > max_duration_len: + sample_pcm = sample_pcm[..., :max_duration_len] + sample_pcm = sample_pcm[None].to(device="cpu") - mimi_encode = MimiEncode(mimi) - chunk = sample_pcm[..., 0:pcm_chunk_size] - out = mimi_encode(chunk) - exported_encode = torch.export.export(mimi_encode, (chunk,), strict=False).module() + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + print("streaming encoding...") + cpu_encode_res = mimi.encode(sample_pcm) + htp_encode_res = mimi_encode(mimi, sample_pcm, skip_node_id_set, skip_node_op_set) + cpu_decode_res = mimi.decode(cpu_encode_res) + htp_decode_res = mimi_decode( + mimi, htp_encode_res.to(torch.int32), skip_node_id_set, skip_node_op_set + ) + sphn.write_wav( + "cpu_decode_res.wav", cpu_decode_res[0, 0].cpu().numpy(), sample_rate + ) + sphn.write_wav( + "htp_decode_res.wav", htp_decode_res[0, 0].cpu().numpy(), sample_rate + ) def main(args): @@ -209,7 +199,8 @@ def main(args): # emb = torch.load('emb.pt') with torch.no_grad(): - mimi_test(mimi, args) + export_mimi(mimi, args) + if __name__ == "__main__": @@ -218,16 +209,14 @@ def main(args): parser.add_argument( "-a", "--artifact", - help="path for storing generated artifacts by this example. Default ./ssd300_vgg16", + help="path for storing generated artifacts by this example. Default ./mimi", default="./mimi", type=str, ) parser.add_argument("--mimi-weight", type=str) parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO) - # parser.add_argument( - # "--device", type=str, default="cpu" if torch.cuda.device_count() else "cpu" - # ) + parser.add_argument("--profile", action="store_true") args = parser.parse_args() diff --git a/examples/qualcomm/oss_scripts/moshi/moshi_example.py b/examples/qualcomm/oss_scripts/moshi/moshi_example.py deleted file mode 100644 index edfda9f9aab..00000000000 --- a/examples/qualcomm/oss_scripts/moshi/moshi_example.py +++ /dev/null @@ -1,189 +0,0 @@ -# # Copyright (c) Qualcomm Innovation Center, Inc. -# # All rights reserved -# # -# # This source code is licensed under the BSD-style license found in the -# # LICENSE file in the root directory of this source tree. - -# import json -# import os -# import sys -# from multiprocessing.connection import Client - -# import numpy as np - -# import torch -# from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -# from executorch.examples.qualcomm.utils import ( -# build_executorch_binary, -# get_imagenet_dataset, -# make_output_dir, -# parse_skip_delegation_node, -# setup_common_args_and_variables, -# SimpleADB, -# topk_accuracy, -# ) - -# from huggingface_hub import hf_hub_download - -# # python examples/qualcomm/oss_scripts/moshi/moshi.py -b build-android -H mlgtw-linux2 -s acfa9311 -m SM8650 - - -# def main(args): -# pte_filename = "moshi_qnn" -# sys.path.insert(0, "../moshi/moshi/") -# from moshi.models import LMGen, loaders -# from moshi.modules.seanet import SEANetEncoder - -# skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) - -# # ensure the working directory exist. 
-# os.makedirs(args.artifact, exist_ok=True) - -# # ---------------------------------method1------------------------ -# # _seanet_kwargs = { -# # "channels": 1, -# # "dimension": 512, -# # "causal": True, -# # "n_filters": 64, -# # "n_residual_layers": 1, -# # "activation": "ELU", -# # "compress": 2, -# # "dilation_base": 2, -# # "disable_norm_outer_blocks": 0, -# # "kernel_size": 7, -# # "residual_kernel_size": 3, -# # "last_kernel_size": 3, -# # # We train using weight_norm but then the weights are pre-processed for inference so -# # # that we can use a normal convolution. -# # "norm": "none", -# # "pad_mode": "constant", -# # "ratios": [8, 6, 5, 4], -# # "true_skip": True, -# # } - -# # mimi = SEANetEncoder(**_seanet_kwargs).eval() -# # ---------------------------------method1------------------------ - -# mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME) -# mimi = loaders.get_mimi(mimi_weight, device="cpu") -# mimi.set_num_codebooks(8) # up to 32 for mimi, but limited to 8 for moshi. - -# wav = (torch.randn(1, 1, 24000 * 10),) # should be [B, C=1, T] - -# build_executorch_binary( -# mimi.eval(), -# wav, -# args.model, -# f"{args.artifact}/{pte_filename}", -# [wav], -# skip_node_id_set=skip_node_id_set, -# skip_node_op_set=skip_node_op_set, -# quant_dtype=QuantDtype.use_8a8w, -# shared_buffer=args.shared_buffer, -# ) -# # import pdb; pdb.set_trace() -# # with torch.no_grad(): -# # codes = mimi.encode(wav) # [B, K = 8, T] -# # decoded = mimi.decode(codes) #################################################### Unused, more like showcase - -# # # Supports streaming too. -# # frame_size = int(mimi.sample_rate / mimi.frame_rate) -# # all_codes = [] -# # with mimi.streaming(batch_size=1): -# # for offset in range(0, wav.shape[-1], frame_size): -# # frame = wav[:, :, offset: offset + frame_size] -# # codes = mimi.encode(frame) -# # assert codes.shape[-1] == 1, codes.shape -# # all_codes.append(codes) - -# # ## WARNING: When streaming, make sure to always feed a total amount of audio that is a multiple -# # # of the frame size (1920), otherwise the last frame will not be complete, and thus -# # # will not be encoded. For simplicity, we recommend feeding in audio always in multiple -# # # of the frame size, so that you always know how many time steps you get back in `codes`. - -# # # Now if you have a GPU around. -# # # mimi.cuda() -# # moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME) -# # moshi = loaders.get_moshi_lm(moshi_weight, device='cpu') -# # lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7) # this handles sampling params etc. -# # out_wav_chunks = [] -# # # Now we will stream over both Moshi I/O, and decode on the fly with Mimi. -# # with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1): -# # for idx, code in enumerate(all_codes): -# # tokens_out = lm_gen.step(code) -# # # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token. 
-# # if tokens_out is not None: -# # wav_chunk = mimi.decode(tokens_out[:, 1:]) -# # out_wav_chunks.append(wav_chunk) -# # print(idx, end='\r') -# # out_wav = torch.cat(out_wav_chunks, dim=-1) -# # import pdb; pdb.set_trace() - -# from datasets import Audio, load_dataset -# from transformers import AutoFeatureExtractor, MimiModel - -# librispeech_dummy = load_dataset( -# "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" -# ) - -# # load model and feature extractor -# model = MimiModel.from_pretrained("kyutai/mimi") -# feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi") - -# # load audio sample -# librispeech_dummy = librispeech_dummy.cast_column( -# "audio", Audio(sampling_rate=feature_extractor.sampling_rate) -# ) -# # audio_sample = lib[-1]["audio"]["array"] -# audio_sample = torch.randn(24000 * 10) -# inputs = feature_extractor( -# raw_audio=audio_sample, -# sampling_rate=feature_extractor.sampling_rate, -# return_tensors="pt", -# ) -# audio_sample = ( -# inputs["input_values"], -# inputs["padding_mask"], -# ) -# build_executorch_binary( -# model.eval(), -# audio_sample, -# args.model, -# f"{args.artifact}/{pte_filename}", -# [audio_sample], -# skip_node_id_set=skip_node_id_set, -# skip_node_op_set=skip_node_op_set, -# quant_dtype=QuantDtype.use_8a8w, -# shared_buffer=args.shared_buffer, -# ) -# import pdb - -# pdb.set_trace() -# encoder_outputs = model.encode( -# inputs["input_values"], inputs["padding_mask"] -# ) # torch.randn(1, 240000), torch.randn(1, 1, 240000) -# audio_values = model.decode(encoder_outputs.audio_codes, inputs["padding_mask"])[0] -# # or the equivalent with a forward pass -# audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values - - -# if __name__ == "__main__": -# parser = setup_common_args_and_variables() - -# parser.add_argument( -# "-a", -# "--artifact", -# help="path for storing generated artifacts by this example. Default ./ssd300_vgg16", -# default="./moshi", -# type=str, -# ) - -# args = parser.parse_args() -# try: -# main(args) -# except Exception as e: -# if args.ip and args.port != -1: -# with Client((args.ip, args.port)) as conn: -# conn.send(json.dumps({"Error": str(e)})) -# else: -# raise Exception(e) diff --git a/install_requirements.py b/install_requirements.py index 06dfbd9e9a6..9353dad180e 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -67,7 +67,7 @@ def python_is_compatible(): # NOTE: If a newly-fetched version of the executorch repo changes the value of # NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -NIGHTLY_VERSION = "dev20250301" +NIGHTLY_VERSION = "dev20250311" def install_requirements(use_pytorch_nightly):