diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 6d50bbb51b2..d8e6237ce32 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -11,6 +11,7 @@ from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear from .decompose_any import DecomposeAny +from .decompose_cdist import DecomposeCDist from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm @@ -27,6 +28,7 @@ from .recompose_pixel_unshuffle import RecomposePixelUnshuffle from .recompose_rms_norm import RecomposeRmsNorm from .reduce_dynamic_range import ReduceDynamicRange +from .remove_0d_tensor import Remove0DTensor from .remove_redundancy import RemoveRedundancy from .replace_arange_args import ReplaceArangeArgs from .replace_index_put_input import ReplaceIndexPutInput @@ -40,8 +42,9 @@ AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, - DecomposeAny, ConvertUpsampleBicubicWithBilinear, + DecomposeAny, + DecomposeCDist, DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, @@ -58,6 +61,7 @@ RecomposePixelUnshuffle, RecomposeRmsNorm, ReduceDynamicRange, + Remove0DTensor, RemoveRedundancy, ReplaceArangeArgs, ReplaceIndexPutInput, diff --git a/backends/qualcomm/_passes/annotate_stack.py b/backends/qualcomm/_passes/annotate_stack.py index c42804af2f2..5fbfde058b2 100644 --- a/backends/qualcomm/_passes/annotate_stack.py +++ b/backends/qualcomm/_passes/annotate_stack.py @@ -17,14 +17,16 @@ class AnnotateStack(ExportPass): generated after quantization process. """ - decomp_ops = [torch.ops.aten.unbind.int] + decomp_ops = [torch.ops.aten.stack.default] def __init__(self, edge_program: torch.export.ExportedProgram): super(AnnotateStack, self).__init__() self.edge_program = edge_program def _annotate_stack(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"]) + partitions = get_source_partitions( + graph_module.graph, [torch.stack, torch.ops.aten.stack.default, "stack"] + ) for _, src_partitions in partitions.items(): for src_partition in src_partitions: output = src_partition.output_nodes[0] diff --git a/backends/qualcomm/_passes/annotate_unbind.py b/backends/qualcomm/_passes/annotate_unbind.py index 0efa1638bc4..426285e872b 100644 --- a/backends/qualcomm/_passes/annotate_unbind.py +++ b/backends/qualcomm/_passes/annotate_unbind.py @@ -24,7 +24,9 @@ def __init__(self, edge_program: torch.export.ExportedProgram): self.edge_program = edge_program def _annotate_unbind(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"]) + partitions = get_source_partitions( + graph_module.graph, [torch.unbind, torch.ops.aten.unbind.int, "unbind"] + ) for _, src_partitions in partitions.items(): for src_partition in src_partitions: if src_partition.input_nodes[0].target in dq_ops: diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py index 947b631dbbf..72dc29c2880 100644 --- a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py +++ b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter +from executorch.backends.qualcomm.utils.constants import 
QCOM_REQUANTIZE from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -43,6 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule): unsqueeze_node.meta = copy_meta( input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)} ) + with graph_module.graph.inserting_after(unsqueeze_node): filter_node = node.args[1] @@ -92,6 +94,14 @@ def call(self, graph_module: torch.fx.GraphModule): ), ) squeeze_node.meta = copy_meta(node.meta) + + if QCOM_REQUANTIZE in input_node.meta: + input_node.meta.pop(QCOM_REQUANTIZE) + if QCOM_REQUANTIZE in node.meta: + squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[ + QCOM_REQUANTIZE + ] + conv2d_node.meta.pop(QCOM_REQUANTIZE, None) for user in node.users.copy(): user.replace_input_with(node, squeeze_node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_cdist.py b/backends/qualcomm/_passes/decompose_cdist.py new file mode 100644 index 00000000000..d18a0295ffb --- /dev/null +++ b/backends/qualcomm/_passes/decompose_cdist.py @@ -0,0 +1,81 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + + +class CDist(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + # Step 1: Compute differences + diff = x.unsqueeze(-2) - y.unsqueeze(-3) + + # Step 2: Square differences + sq_diff = diff**2 + + # Step 3: Sum of squares + sum_sq_diff = sq_diff.sum(dim=-1) + + # Step 4: Square root + distances = torch.sqrt(sum_sq_diff) + + return distances + + +class DecomposeCDist(ExportPass): + """ + Decompose for math equivalent op. 
+ """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + model = CDist() + if torch.ops.aten.cdist.default == node.target: + if len(node.args) > 2: + assert ( + node.args[2] == 2 + ), "Currently only p=2 is supported for CDist Decomposition" + decomposed_module = torch.export.export( + model, + (node.args[0].meta["val"], node.args[1].meta["val"]), + strict=True, + ).module() + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": node.args[0], "y": node.args[1]} + + for decomposed_node in decomposed_module.graph.nodes: + # no need to copy existent 'output' + if decomposed_node.op == "output": + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[decomposed_node.args[0][0]], + ) + # no need to copy existent placeholders + elif decomposed_node.op == "placeholder": + # replace node map from string to graph node + remap[decomposed_node] = remap.pop(decomposed_node.name) + else: + remap[decomposed_node] = graph.node_copy( + decomposed_node, + arg_transform=lambda x, remap=remap: remap[x], + ) + + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index 93abfe621bc..9b3a308813e 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -53,7 +53,13 @@ class TensorOpInfo: } -SKIP_LIFT_OPS = {aten.full_like.default, aten.arange.start_step} +SKIP_LIFT_OPS = { + aten.full_like.default, + aten.arange.start_step, + aten.arange.default, + aten.scalar_tensor.default, + aten.elu.default, +} class LiftConstantScalarOperands(ExportPass): diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index ed3c8eb217e..dda01e4a8a1 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -16,6 +16,7 @@ ConvertConv1dToConv2d, ConvertUpsampleBicubicWithBilinear, DecomposeAny, + DecomposeCDist, DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, @@ -32,6 +33,7 @@ RecomposePixelUnshuffle, RecomposeRmsNorm, ReduceDynamicRange, + Remove0DTensor, RemoveRedundancy, ReplaceArangeArgs, ReplaceIndexPutInput, @@ -71,7 +73,7 @@ def get_capture_program_passes(): # If a pass is activated, it will be executed by default. 
default_passes_and_setting = [ (AnnotateQuantAttrs, True), - (AnnotateStack, False), + (AnnotateStack, True), (AnnotateUnbind, True), (ConvertBmmToMatmul, True), (ConvertConv1dToConv2d, True), @@ -84,6 +86,7 @@ def get_capture_program_passes(): (LayoutTransform, True), (RecomposePixelUnshuffle, True), (RecomposeRmsNorm, False), + (Remove0DTensor, True), (RemoveRedundancy, True), (ReplaceIndexPutInput, True), (TagQuantIO, False), @@ -176,7 +179,23 @@ def transform_for_to_edge_pipeline( return exported_program + # Before quantizer + def transform_for_annotation_pipeline(self, graph_module: GraphModule): + self.add_pass(ReduceDynamicRange()) + self.add_pass(RecomposePixelUnshuffle(quantization_capture=True)) + self.add_pass(ReplaceArangeArgs()) + self.add_pass(DecomposeCDist()) + self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeSilu()) + self.add_pass(DecomposeEinsum()) + self.add_pass(DecomposeExpM1()) + self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) + self.add_pass(ReplaceInfValues()) + self.add_pass(LiftConstantScalarOperands()) + return self._transform(graph_module) + def transform_for_export_pipeline(self, exported_program: ExportedProgram): + self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeExpM1()) @@ -191,16 +210,3 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): self.add_pass(LayoutTransform(exported_program, insert_permute=True)) self.add_pass(FuseConsecutiveTranspose()) return self._transform(exported_program.graph_module) - - def transform_for_annotation_pipeline(self, graph_module: GraphModule): - self.add_pass(ReduceDynamicRange()) - self.add_pass(RecomposePixelUnshuffle(quantization_capture=True)) - self.add_pass(ReplaceArangeArgs()) - self.add_pass(DecomposeScaledDotProductAttention()) - self.add_pass(DecomposeSilu()) - self.add_pass(DecomposeEinsum()) - self.add_pass(DecomposeExpM1()) - self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) - self.add_pass(ReplaceInfValues()) - self.add_pass(LiftConstantScalarOperands()) - return self._transform(graph_module) diff --git a/backends/qualcomm/_passes/remove_0d_tensor.py b/backends/qualcomm/_passes/remove_0d_tensor.py new file mode 100644 index 00000000000..1e1d711c2b8 --- /dev/null +++ b/backends/qualcomm/_passes/remove_0d_tensor.py @@ -0,0 +1,36 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class Remove0DTensor(ExportPass): + """ + QNN does not allow 0D tensor, we remove the node that will output an 0D tensor. + Before adding operations to the list of nodes to be removed, please ensure that it will not change the logic. 
+ """ + + remove_ops = { + exir_ops.edge.aten.select.int, + exir_ops.edge.aten.select_copy.int, + } + + def __init__(self, quantization_capture=False) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target in self.remove_ops and len(node.meta["val"].shape) == 0: + for user_n in list(node.users.keys()): + user_n.replace_input_with(node, node.args[0]) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 7b5e72d461d..d9eb188614c 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -34,7 +34,7 @@ not_supported_operator, to_be_implemented_operator, ) -from .utils import generate_qnn_executorch_option, get_skip_decomp_table +from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table class QnnOperatorSupport(OperatorSupportBase): @@ -181,5 +181,4 @@ def ops_to_not_decompose( self, ep: ExportedProgram ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: do_not_decompose = get_skip_decomp_table() - - return do_not_decompose, None + return (do_not_decompose, filter_fn) diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 6931e35e6e3..816d1ac1d9b 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -24,6 +24,21 @@ def generate_qnn_executorch_option( return qnn_compile_spec_buffer +# Logic to determine whether to skip decompose and has higher priority than get_skip_decomp_table() +def filter_fn(node: torch.fx.Node) -> bool: + # QNN does not support int32/int64 IO for the following OPs. + potential_i32_i64_io_ops = [ + torch.ops.aten.stack.default, + torch.ops.aten.unbind.int, + ] + if node.target in potential_i32_i64_io_ops and node.meta["val"].dtype in [ + torch.int32, + torch.int64, + ]: + return False + return True + + def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: do_not_decompose = [ torch.ops.aten.adaptive_avg_pool2d.default, @@ -41,7 +56,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.stack.default, torch.ops.aten.upsample_bicubic2d.vec, # This request is ignored because it is in a blocklist. 
Refer to exir/program/_program.py - # torch.ops.aten.unbind.int, + torch.ops.aten.unbind.int, torch.ops.pt2e_quant.quantize_affine.default, torch.ops.pt2e_quant.dequantize_affine.default, ] diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index f59df46b3c1..469a801feeb 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -97,6 +97,7 @@ def annotate_in_out_obs_sharing_op( QUANT_ANNOTATION_KEY not in input_act.meta or not input_act.meta[QUANT_ANNOTATION_KEY]._annotated or input_act.meta[QUANT_ANNOTATION_KEY].output_qspec is None + or not _is_float_tensor(input_act) ): return @@ -132,9 +133,10 @@ def annotate_single_in_single_out( return input_qspec_map = {} - input_act = node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation + if _is_float_tensor(node.args[0]): + input_act = node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = quantization_config.input_activation if _is_float_tensor(node): node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( @@ -177,7 +179,9 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None ) -@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor]) +@register_annotator( + [torch.ops.aten.add, torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor] +) def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) @@ -933,6 +937,11 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: node.meta["source_fn_stack"] = [(node, torch.bmm)] +@register_annotator([torch.ops.aten.cdist.default]) +def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + @register_annotator( [ torch.ops.aten.conv2d.default, @@ -941,7 +950,7 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: torch.ops.aten.conv_transpose1d.default, ] ) -def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: +def annotate_conv(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): return @@ -1118,15 +1127,17 @@ def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} assert isinstance(first_input_node, Node) assert isinstance(node, Node) - input_qspec_map[first_input_node] = quantization_config.input_activation - share_qparams_with_input_act0_qspec = SharedQuantizationSpec( - (first_input_node, node) - ) + if _is_float_tensor(first_input_node): + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) for input_node in input_nodes[1:]: if input_node not in input_qspec_map: assert isinstance(input_node, Node) - input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + if _is_float_tensor(input_node): + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -1140,7 +1151,6 @@ def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None # Seems like unbind.int can be either float or int. Only quant when input is float. 
if _is_annotated([node]) or not _is_float_tensor(node.args[0]): return - input_qspec_map = {} input_act = node.args[0] assert isinstance(input_act, Node) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 33237f3bebe..bda91609f1c 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -6,7 +6,10 @@ from typing import Sequence import torch -from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY +from executorch.backends.qualcomm.quantizer.annotators import ( + _is_float_tensor, + QUANT_ANNOTATION_KEY, +) from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, get_8a8w_qnn_ptq_config, @@ -23,6 +26,38 @@ from torch.fx import Node +def annotate_mimi_decoder(gm: torch.fx.GraphModule): + """ + The 1st transpose conv in mimi decoder is really sensitive to scale/offset in 16a8w, which causes execution failure. + Annotate 1st transpose conv as 8a8w to prevent execution failure. + """ + quantization_config_8a8w = get_8a8w_qnn_ptq_config() + for node in gm.graph.nodes: + if not _is_float_tensor(node): + continue + elif node.target == torch.ops.aten.conv_transpose1d.default: + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + input_spec = quantization_config_8a8w.input_activation + input_qspec_map[input_act] = input_spec + + weight = node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = quantization_config_8a8w.weight + + if len(node.args) > 2 and isinstance(node.args[2], Node): + bias = node.args[2] + input_qspec_map[bias] = quantization_config_8a8w.bias + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config_8a8w.output_activation, + _annotated=True, + ) + break + + def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 6570b4befcb..3cba9db39b0 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -190,6 +190,14 @@ def forward(self, x, y): return torch.cat((y, y, x, x), axis=2) +class CDist(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.cdist(x, y, p=2) + + class Ceil(torch.nn.Module): def __init__(self): super().__init__() @@ -1588,3 +1596,14 @@ def forward(self, x): return torch.nn.functional.softmax( torch.where(x >= 0, 0.1, float("-inf")), dim=-1 ) + + +# Mimi Decoder has 0D tensor which QNN cannot handle. 
+class ZeroDimTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + input1 = torch.zeros(1) + selected_element = torch.select(input1, 0, 0) + return torch.add(x, selected_element) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 6579dc7c468..57beb201987 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -176,6 +176,14 @@ def test_qnn_backend_cat(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cdist(self): + module = CDist() # noqa: F405 + sample_input = ( + torch.randn(1, 125, 256), + torch.randn(1, 2048, 256), + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_chunk_single(self): module = Chunk() # noqa: F405 sample_input = (torch.randn(1, 1, 4, 3),) @@ -870,14 +878,14 @@ def test_qnn_backend_where(self): Where(), # noqa: F405 WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405 WhereConstantOther(), # noqa: F405 - # WhereConstantAll(), # noqa: F405 TODO: constant dtype does not propogate when doing const i64->32, causing where to fail since where does not support int64 output + WhereConstantAll(), # noqa: F405 WhereConstantInf(), # noqa: F405 ] sample_inputs = [ (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)), (torch.randn(3, 2),), (torch.randn(3, 2),), - # (torch.randn(3, 2),), + (torch.randn(3, 2),), (torch.randn(30, 20),), ] for i, module in enumerate(modules): @@ -1006,6 +1014,11 @@ def test_qnn_backend_view_permute_matmul(self): sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256])) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_zero_dim_tensor(self): + module = ZeroDimTensor() # noqa: F405 + sample_input = (torch.randn(1, 256, 125),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_example_models(self): # TODO Fix MobileBertModelExample and TorchVisionViTModel instances = [ @@ -1199,6 +1212,15 @@ def test_qnn_backend_cat(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cdist(self): + module = CDist() # noqa: F405 + sample_input = ( + torch.randn(1, 125, 256), + torch.randn(1, 2048, 256), + ) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_chunk_single(self): module = Chunk() # noqa: F405 sample_input = (torch.randn(1, 1, 4, 3),) @@ -2030,14 +2052,14 @@ def test_qnn_backend_where(self): Where(), # noqa: F405 WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405 WhereConstantOther(), # noqa: F405 - # WhereConstantAll(), # noqa: F405, TODO: constant dtype does not propogate when doing const i64->32, causing where to fail since where does not support int64 output + WhereConstantAll(), # noqa: F405 WhereConstantInf(), # noqa: F405 ] sample_inputs = [ (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)), (torch.randn(3, 2),), (torch.randn(3, 2),), - # (torch.randn(3, 2),), + (torch.randn(3, 2),), (torch.randn(30, 20),), ] for i, module in enumerate(modules): @@ -2211,6 +2233,12 @@ def test_qnn_backend_view_permute_matmul(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_zero_dim_tensor(self): + module = ZeroDimTensor() # noqa: F405 + sample_input = 
(torch.randn(1, 256, 125),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_example_models(self): instances = [ { diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index f7b966ee8ea..e0ebc5beebe 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -16,7 +16,7 @@ import torch -from executorch.backends.qualcomm._passes import AnnotateStack +from executorch.backends.qualcomm._passes import AnnotateStack, AnnotateUnbind from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from executorch.backends.qualcomm.builders.node_visitor import ( @@ -304,11 +304,12 @@ def get_decomp_table(passes_job) -> Dict[torch._ops.OperatorBase, Callable]: skip_decompositions = get_skip_decomp_table() # If we want to annotate the decomposed ops, then we should decompose the operation. - if passes_job and passes_job.get(AnnotateStack, False): + if passes_job: skip_decompositions = [ skip_decomp_op for skip_decomp_op in skip_decompositions - if skip_decomp_op not in AnnotateStack.decomp_ops + if skip_decomp_op + not in AnnotateStack.decomp_ops + AnnotateUnbind.decomp_ops ] remove_decompositions(source_decompositions, skip_decompositions) diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py new file mode 100644 index 00000000000..6b59a71ae64 --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -0,0 +1,402 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# import argparse +import io +import json +import os +import random +from multiprocessing.connection import Client + +import numpy as np +import requests + +import sphn +import torch + +import torch.nn as nn +import torchaudio + +from executorch.backends.qualcomm.quantizer.custom_annotation import ( + annotate_mimi_decoder, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + make_output_dir, + make_quantizer, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) + +from huggingface_hub import hf_hub_download +from moshi.models import loaders + +from torch.ao.quantization.observer import MinMaxObserver + + +def seed_all(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # for multi-GPU setups + random.seed(seed) + np.random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def read_mp3_from_url(url): + response = requests.get(url) + response.raise_for_status() # Ensure request is successful + + # Convert to a file-like object + audio_stream = io.BytesIO(response.content) + + # Load audio using torchaudio + waveform, sample_rate = torchaudio.load(audio_stream, format="mp3") + + return waveform.numpy(), sample_rate + + +def compute_scores(cpu_decode_res: torch.Tensor, htp_decode_res: torch.Tensor): + assert cpu_decode_res.shape == htp_decode_res.shape, "Tensor shapes do not match" + abs_diff = torch.abs(cpu_decode_res - htp_decode_res) + atol = torch.max(abs_diff) + print("Atol: ", atol) + + cpu_decode_res = cpu_decode_res.float() + htp_decode_res = htp_decode_res.float() + error = cpu_decode_res - htp_decode_res + 
original_power = torch.mean(torch.pow(cpu_decode_res, 2)) + error_power = torch.mean(torch.pow(error, 2)) + sqnr = 10 * torch.log10(original_power / error_power) + print("SQNR: ", sqnr) + + +def test_decoder_with_emb_input(mimi, args): + class MimiDecode(nn.Module): + def __init__(self, mimi: nn.Module): + super().__init__() + self.mimi_model = mimi + + def forward(self, x): + x = x.transpose(1, 2) + x = self.mimi_model.upsample(x) + (emb,) = self.mimi_model.decoder_transformer(x) + emb.transpose(1, 2) + with self.mimi_model._context_for_encoder_decoder: + out = self.mimi_model.decoder(emb) + return out + + emb_input = torch.rand(1, 1, 512, device="cpu") + mimi_decode = MimiDecode(mimi).eval() + cpu_res = mimi_decode(emb_input) + pte_filename = "mimi_decoder_emb_qnn" + + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_16a8w, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) + quantizer.add_custom_quant_annotations((annotate_mimi_decoder,)) + + emb_inputs = [(emb_input,)] + build_executorch_binary( + mimi_decode, + emb_inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + emb_inputs, + custom_quantizer=quantizer, + quant_dtype=QuantDtype.use_16a8w, + shared_buffer=args.shared_buffer, + ) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=emb_inputs, input_list="input_0_0.raw\n") + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + emb_predictions = [] + for i in range(len(emb_inputs)): + np_arr = np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + emb_predictions.append(torch.from_numpy(np_arr).view(1, 1, 1920)) + print("Emb input test results") + compute_scores(cpu_res, emb_predictions[0]) + + +def mimi_encode( + mimi, + encode_inputs, + encoder_input_list, + pcm_chunk_size, + skip_node_id_set, + skip_node_op_set, +) -> torch.Tensor: + class MimiEncode(nn.Module): + def __init__(self, mimi: nn.Module): + super().__init__() + self.mimi_model = mimi + + def forward(self, x): + return self.mimi_model.encode(x) + + mimi_encode_model = MimiEncode(mimi) + + pte_filename = "mimi_encoder_qnn" + build_executorch_binary( + mimi_encode_model.eval(), + encode_inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + encode_inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + shared_buffer=args.shared_buffer, + ) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=encode_inputs, input_list=encoder_input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + encoder_predictions = [] + # Num chunks should align with args.chunks_per_batch + num_chunks = encode_inputs[0][0].shape[-1] // pcm_chunk_size + for i in range(len(encode_inputs)): + np_arr = 
np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.int64 + ) + encoder_predictions.append(torch.from_numpy(np_arr).view(1, 8, num_chunks)) + return encoder_predictions + + +def mimi_decode( + mimi, encode_res_list, pcm_chunk_size, skip_node_id_set, skip_node_op_set +) -> torch.Tensor: + class MimiDecode(nn.Module): + def __init__(self, mimi: nn.Module): + super().__init__() + self.mimi_model = mimi + + def forward(self, x): + return self.mimi_model.decode(x) + + mimi_decode_model = MimiDecode(mimi) + decode_inputs, decode_input_list = [], "" + for index, encoder_res in enumerate(encode_res_list): + decode_inputs.append((encoder_res.to(torch.int32),)) + decode_input_list += f"input_{index}_0.raw\n" + + pte_filename = "mimi_decoder_qnn" + + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_16a8w, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) + quantizer.add_custom_quant_annotations((annotate_mimi_decoder,)) + + build_executorch_binary( + mimi_decode_model.eval(), + decode_inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + decode_inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + custom_quantizer=quantizer, + quant_dtype=QuantDtype.use_16a8w, + shared_buffer=args.shared_buffer, + ) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=decode_inputs, input_list=decode_input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + decoder_predictions = [] + # Num chunks should align with args.chunks_per_batch + num_chunks = decode_inputs[0][0].shape[-1] + shape = num_chunks * pcm_chunk_size + for i in range(len(decode_inputs)): + np_arr = np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + decoder_predictions.append(torch.from_numpy(np_arr).view(1, 1, shape)) + htp_decode_res = torch.cat(decoder_predictions, dim=-1) + + return htp_decode_res + + +def export_mimi(mimi, args, max_duration_sec=10.0): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + os.makedirs(args.artifact, exist_ok=True) + + if args.emb_input_test: + test_decoder_with_emb_input(mimi, args) + return + + sample_rate = mimi.sample_rate + url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3" + sample_pcm, sample_sr = read_mp3_from_url(url) + sample_rate = mimi.sample_rate + sample_pcm = torch.tensor(sample_pcm, device="cpu") + max_duration_len = int(sample_rate * max_duration_sec) + if sample_pcm.shape[-1] > max_duration_len: + sample_pcm = sample_pcm[..., :max_duration_len] + sample_pcm = sample_pcm[None].to(device="cpu") + + encoder_inputs, encoder_input_list = [], "" + # 1920 chunk_size = 0.08sec + pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) + batch_size = pcm_chunk_size * args.chunks_per_batch + count = 0 + for start_idx in range(0, sample_pcm.shape[-1], batch_size): + end_idx = min(sample_pcm.shape[-1], start_idx + batch_size) + chunk = sample_pcm[..., start_idx:end_idx] + encoder_inputs.append((chunk,)) + encoder_input_list += f"input_{count}_0.raw\n" + count += 1 + + print("streaming encoding...") + cpu_encode_res = 
mimi.encode(sample_pcm) + htp_encode_res = mimi_encode( + mimi, + encoder_inputs, + encoder_input_list, + pcm_chunk_size, + skip_node_id_set, + skip_node_op_set, + ) + + # Leave it here for now, uncomment this to check htp_encoder with cpu_decoder + # htp_res = torch.cat(htp_encode_res, dim=-1) + # cpu_decode_htp_encode = mimi.decode(htp_res) + # sphn.write_wav("cpu_decode_htp_encode.wav", cpu_decode_htp_encode[0, 0].cpu().numpy(), sample_rate) + + print("streaming decoding...") + cpu_decode_res = mimi.decode(cpu_encode_res) + # TODO: Enable streaming mode, which is the correct way to execute 1 chunk at a time. + # with mimi.streaming(1): + htp_decode_res = mimi_decode( + mimi, htp_encode_res, pcm_chunk_size, skip_node_id_set, skip_node_op_set + ) + compute_scores(cpu_decode_res, htp_decode_res) + + sphn.write_wav( + f"{args.artifact}/cpu_decode_res.wav", + cpu_decode_res[0, 0].cpu().numpy(), + sample_rate, + ) + sphn.write_wav( + f"{args.artifact}/htp_decode_res.wav", + htp_decode_res[0, 0].cpu().numpy(), + sample_rate, + ) + + +def main(args): + seed_all(42424242) + + print("loading mimi") + if args.mimi_weight is None: + args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME) + mimi = loaders.get_mimi(args.mimi_weight, "cpu") + print("mimi loaded") + + with torch.no_grad(): + export_mimi(mimi, args) + + +if __name__ == "__main__": + + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./mimi", + default="./mimi", + type=str, + ) + + parser.add_argument( + "--chunks_per_batch", + help="Number of chunks to process per time. Default is 1 chunk per batch, which equals to 0.08 second", + default=1, + type=int, + ) + + parser.add_argument( + "--emb_input_test", + help="This is just a metrics used to compute accuracy scores, not recommended for general users.", + action="store_true", + default=False, + ) + + parser.add_argument("--mimi-weight", type=str) + parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 83e5b9a0442..242170712e1 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -288,7 +288,7 @@ def build_executorch_binary( skip_node_id_set=None, skip_node_op_set=None, quant_dtype: Optional[QuantDtype] = None, - custom_quantizer=None, + custom_quantizer: Optional[QnnQuantizer] = None, shared_buffer=False, metadata=None, dump_intermediate_outputs=False, @@ -325,8 +325,8 @@ def build_executorch_binary( shared_buffer=shared_buffer, dump_intermediate_outputs=dump_intermediate_outputs, ) - if quant_dtype is not None: - captured_model = torch.export.export(model, inputs, strict=True).module() + if quant_dtype is not None or custom_quantizer is not None: + captured_model = torch.export.export(model, inputs, strict=False).module() if qat_training_data: quantizer = custom_quantizer or make_quantizer( quant_dtype=quant_dtype, is_qat=True
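For reference, the rewrite that the new DecomposeCDist pass applies — broadcasted difference, square, sum over the feature dimension, square root — can be checked numerically against torch.cdist outside the pass. The snippet below is a standalone illustration, not part of the patch; it reuses the tensor shapes from the new CDist test and covers only p=2, matching the assertion in the pass.

import torch

def decomposed_cdist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # (B, P, 1, M) - (B, 1, R, M) broadcasts to pairwise differences of shape (B, P, R, M)
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)
    # Euclidean distance: square root of the sum of squared differences over the feature dim
    return torch.sqrt((diff ** 2).sum(dim=-1))

x = torch.randn(1, 125, 256)
y = torch.randn(1, 2048, 256)
ref = torch.cdist(x, y, p=2)
out = decomposed_cdist(x, y)
print("max abs diff vs. torch.cdist:", (ref - out).abs().max().item())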