diff --git a/backends/arm/common/annotation_meta.py b/backends/arm/common/annotation_meta.py new file mode 100644 index 00000000000..a857e36bb3f --- /dev/null +++ b/backends/arm/common/annotation_meta.py @@ -0,0 +1,19 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ArmAnnotationInfo: + """ + Data class to carry Arm-specific annotation information through the pipeline. + This is intended to be attached to node.meta['custom'] and propagated + through partitioning and backend stages. As it's propagated through the pipeline, + it's intentionally minimal and only carries whether the node is quantized or not. + """ + + quantized: bool + CUSTOM_META_KEY: str = "_arm_annotation_info" diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index f7857894d40..d2630405176 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -19,6 +19,7 @@ FuseQuantizedActivationPass, ) from executorch.backends.arm._passes.insert_table_ops import TableOps +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS from executorch.backends.arm.operator_support.ethos_u55_support import ( EthosU55CastCheck, @@ -135,6 +136,7 @@ def tosa_support_factory( ] if not tosa_spec.support_float(): + negative_checks.append(CheckArmQuantized(reporter)) negative_checks.append(CheckProperQuantization(reporter)) if tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) @@ -162,7 +164,6 @@ class TOSAProINTSupportList(OperatorSupportBase): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList @@ -175,10 +176,80 @@ class TOSAProFPSupportList(OperatorSupportBase): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList +class CheckArmQuantized(OperatorSupportBase): + """ + Check if the node was marked as quantized in the Arm backend. + This is used to ensure that nodes that were quantized in the Arm backend + are only partitioned if they are supported by the TOSA backend. + """ + + def __init__(self, reporter: WhyNoPartitionReporter): + self.reporter = reporter + + def _is_quantized(self, node: torch.fx.Node) -> bool: + """Checks if the node is quantized. + + A node is considered quantized if at least one criteria is met: + - Its dtype is not floating point or complex => integer + - It is one of the special cases where the node has been created in to_edge, e.g. + .Scalar operations that have been promoted .Tensor operations + where the scalar is replaced by a full op. + - It has been marked as quantized in the ArmAnnotationInfo custom meta. + + Args: + node (torch.fx.Node): The FX node to check. + + Returns: + bool: True if the node is quantized, False otherwise. + """ + node_dtype = get_first_fake_tensor(node).dtype + if not node_dtype.is_complex and not node_dtype.is_floating_point: + return True + if node.target in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOT.targeted_ops, + ): + # Special cases where nodes have been created in to_edge, e.g. + # .Scalar operations that have been promoted .Tensor operations + # where the scalar is replaced by a full op. + if all(user.target in Q_OPS for user in node.users): + return True + for user in node.users: + if ( + user.target + == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ): + dim_order_dtype = get_first_fake_tensor(user).dtype + if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point: + return False + else: + return False + return True + return ( + ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) + and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized + ) + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if node.op != "call_function": + return False + + if node.target in (*DQ_OPS, *Q_OPS): + return True + + if not self._is_quantized(node): + self.reporter.report_reject( + node, "Node was not marked as quantized in the Arm backend." + ) + return False + return True + + class CheckProperQuantization(OperatorSupportBase): """ For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize @@ -351,7 +422,6 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - vals = node.meta["val"] tensor_list = vals if isinstance(vals, (list, tuple)) else [vals] @@ -419,7 +489,6 @@ def is_node_supported( class CheckFloat64Inputs(OperatorSupportBase): - def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): @@ -429,7 +498,6 @@ def __init__( def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - for input_node in node.all_input_nodes: tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.float64: diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 90876386aa6..9ba0025fa42 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -15,6 +15,8 @@ from typing import cast +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo + from torch.fx import Node from torchao.quantization.pt2e.quantizer import QuantizationAnnotation @@ -66,4 +68,10 @@ def mark_node_as_annotated(node: Node) -> None: """ if Q_ANNOTATION_KEY not in node.meta: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation() + annotation_info = ArmAnnotationInfo( + quantized=True, + ) node.meta[Q_ANNOTATION_KEY]._annotated = True + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info + node.meta["custom"] = meta_custom diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index b429bacd738..c1261b9b6ed 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -324,6 +324,7 @@ def _match_pattern( torch.ops.aten.view.default, torch.ops.aten.view_as.default, torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, torch.ops.aten.select.int, torch.ops.aten.select_copy.int, torch.ops.aten.slice.Tensor, @@ -356,6 +357,7 @@ def _match_pattern( ] _one_to_one_shared_input_or_input_act_qspec = [ + torch.ops.aten.alias.default, torch.ops.aten.clone.default, torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default, @@ -588,10 +590,10 @@ def any_or_hardtanh_min_zero(n: Node): ] quant_properties.quant_output = None elif node.target in [ - torch.ops.aten.scalar_tensor.default, torch.ops.aten.full.default, torch.ops.aten.full, torch.ops.aten.fill_.Scalar, + torch.ops.aten.scalar_tensor.default, ]: quant_properties.quant_inputs = [] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py index d6d6d6cb39c..46a97fff1df 100644 --- a/backends/arm/test/misc/test_int64.py +++ b/backends/arm/test/misc/test_int64.py @@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor): ConstAdd(torch.int64, 2**40), (torch.rand(10) - 0.5,), ), - "int64_in+float_const": ( - ConstAdd(torch.float32), - (torch.randint(0, 10, (10,)),), - ), "fp32_in+int64_buffer_chain": ( BufferChainAdd(torch.int64), (torch.rand(2, 5, 3) - 0.5,), @@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple): ArmTester( model, inputs, - common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"), + common.get_tosa_compile_spec("TOSA-1.0+FP"), ) .export() .to_edge_transform_and_lower() diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py new file mode 100644 index 00000000000..d18a1d39e45 --- /dev/null +++ b/backends/arm/test/misc/test_quant_custom_meta.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize + + +class AddSigmoidMul(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, y): + return self.sigmoid(x + y) * x + + +def get_selective_quantizer(modules): + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global(get_symmetric_quantization_config()) + for module in modules: + quantizer.set_module_type(module, None) + + return Quantize(quantizer, get_symmetric_quantization_config()) + + +def test_qdq_squeezed_fp_op(): + """Test that a float operation surrounded by quantize-dequantize pairs + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q + |_____Non-delegated____| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = AddSigmoidMul() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op + ) + pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid])) + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + }, + ) + pipeline.run() + + +class MulAddSigmoidConv(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + self.conv = torch.nn.Conv1d(3, 3, 1) + + def forward(self, x, y): + return self.conv(self.sigmoid(x + y * x)) + + +def test_quantized_to_float_transition(): + """Test that a model executing quantized ops followed by float ops + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv + |____Non-delegated___| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = MulAddSigmoidConv() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op + ) + pipeline.change_args( + "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d]) + ) + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 1, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + }, + ) + pipeline.run() diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py index 9506fe727db..a87736d80cd 100644 --- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py +++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py @@ -37,7 +37,8 @@ class TestSD3Transformer2DModel: ops_after_partitioner_INT = { "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, - "torch.ops.higher_order.executorch_call_delegate": 2, + "torch.ops.higher_order.executorch_call_delegate": 3, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, } def _prepare_inputs( diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index 4896074b544..e585e82ad9d 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data): @parametrize( "test_data", module_tests, - {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"}, ) def test_nn_functional_INT(test_data): module, inputs = test_data diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py index eef32259c10..5c829acc145 100644 --- a/backends/arm/test/ops/test_eye.py +++ b/backends/arm/test/ops/test_eye.py @@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") + ) pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 6eb1dcbef72..4d3a1eaf1b1 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -299,6 +299,10 @@ def ops_to_not_decompose( torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, ] + constant_ops_to_not_decompose = [ + torch.ops.aten.eye.default, + torch.ops.aten.linspace.default, + ] def filter_fn(node: torch.fx.Node) -> bool: """Return True to keep selected ops intact inside quantized regions. @@ -315,35 +319,62 @@ def filter_fn(node: torch.fx.Node) -> bool: bool: True to keep the op intact; otherwise, False. """ + dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default q = torch.ops.quantized_decomposed.quantize_per_tensor.default if node.target in ops_to_not_decompose_if_quant_op: - # Assume we should not decompose the operator (it is quantized) - should_not_decompose = True - input_nodes = node.all_input_nodes - ouput_nodes = node.users + output_nodes = node.users - for inp in input_nodes: - if inp.target != dq: - should_not_decompose = False + should_not_decompose = all(inp.target == dq for inp in input_nodes) - for out in ouput_nodes: - if out.target != q: - should_not_decompose = False + should_not_decompose = should_not_decompose and all( + out.target == q for out in output_nodes + ) return should_not_decompose + elif node.target in constant_ops_to_not_decompose: + # We only want to tag nodes as do_not_decompose if we are sure that + # we can partition them. We partition them if one or more of the + # following is true: + # 1. The TOSA spec supports floating point. + # 2. The node outputs an integer type. + # 3. All the node outputs are quantized. + # 4. All users cast the output to an integer type. + # If none of the above is true we will not tag the node and + # it will be decomposed. + if self.tosa_spec.support_float(): + return True + + dtype = get_first_fake_tensor(node).dtype + if not dtype.is_floating_point and not dtype.is_complex: + return True + + output_nodes = node.users + if all(out.target == q for out in output_nodes): + return True + + for user in output_nodes: + if user.target == torch.ops.aten.to.dtype: + cast_dtype = get_first_fake_tensor(user).dtype + if cast_dtype.is_complex or cast_dtype.is_floating_point: + return False + else: + return False + return False # By default, do not decompose the operator return True - ops_to_not_decompose = [ - torch.ops.aten.linear.default, - torch.ops.aten.eye.default, - torch.ops.aten.linspace.default, - torch.ops.aten.logit.default, - ] + ops_to_not_decompose_if_quant_op + ops_to_not_decompose = ( + [ + torch.ops.aten.linear.default, + torch.ops.aten.logit.default, + ] + + ops_to_not_decompose_if_quant_op + + constant_ops_to_not_decompose + ) if not self.tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d