19 changes: 19 additions & 0 deletions backends/arm/common/annotation_meta.py
@@ -0,0 +1,19 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass


@dataclass(frozen=True)
class ArmAnnotationInfo:
"""
    Data class that carries Arm-specific annotation information through the pipeline.
    It is intended to be attached to node.meta['custom'] and propagated through
    partitioning and backend stages. Because it travels through the whole pipeline,
    it is intentionally minimal and only records whether the node is quantized.
"""

quantized: bool
CUSTOM_META_KEY: str = "_arm_annotation_info"
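
A minimal usage sketch of the dataclass above; the bare dict stands in for fx.Node.meta, and only the import path introduced in this diff is assumed:

from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo

meta: dict = {}  # stand-in for fx.Node.meta
custom = meta.setdefault("custom", {})
custom[ArmAnnotationInfo.CUSTOM_META_KEY] = ArmAnnotationInfo(quantized=True)

# The dataclass is frozen, so the flag cannot be mutated once attached.
assert meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized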
78 changes: 73 additions & 5 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -19,6 +19,7 @@
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
from executorch.backends.arm.operator_support.ethos_u55_support import (
EthosU55CastCheck,
@@ -135,6 +136,7 @@ def tosa_support_factory(
]

if not tosa_spec.support_float():
negative_checks.append(CheckArmQuantized(reporter))
negative_checks.append(CheckProperQuantization(reporter))
if tosa_spec.is_U55_subset:
negative_checks.append(EthosU55NotSupported(reporter))
@@ -162,7 +164,6 @@ class TOSAProINTSupportList(OperatorSupportBase):
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList


@@ -175,10 +176,80 @@ class TOSAProFPSupportList(OperatorSupportBase):
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList


class CheckArmQuantized(OperatorSupportBase):
"""
    Check if the node was marked as quantized in the Arm backend.
    This is used to ensure that only nodes quantized in the Arm backend
    are partitioned on TOSA targets without float support.
"""

def __init__(self, reporter: WhyNoPartitionReporter):
self.reporter = reporter

def _is_quantized(self, node: torch.fx.Node) -> bool:
"""Checks if the node is quantized.

        A node is considered quantized if at least one criterion is met:
        - Its dtype is neither floating point nor complex, i.e. it is an integer type.
        - It is one of the special cases where the node has been created in to_edge, e.g.
          .Scalar operations that have been promoted to .Tensor operations,
          where the scalar is replaced by a full op.
- It has been marked as quantized in the ArmAnnotationInfo custom meta.

Args:
node (torch.fx.Node): The FX node to check.

Returns:
bool: True if the node is quantized, False otherwise.
"""
node_dtype = get_first_fake_tensor(node).dtype
if not node_dtype.is_complex and not node_dtype.is_floating_point:
return True
if node.target in (
exir_ops.edge.aten.full_like.default,
*ComputeConstantOpsAOT.targeted_ops,
):
# Special cases where nodes have been created in to_edge, e.g.
            # .Scalar operations that have been promoted to .Tensor operations,
# where the scalar is replaced by a full op.
if all(user.target in Q_OPS for user in node.users):
return True
for user in node.users:
if (
user.target
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
):
dim_order_dtype = get_first_fake_tensor(user).dtype
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
return False
else:
return False
return True
return (
ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized
)

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
if node.op != "call_function":
return False

if node.target in (*DQ_OPS, *Q_OPS):
return True

if not self._is_quantized(node):
self.reporter.report_reject(
node, "Node was not marked as quantized in the Arm backend."
)
return False
return True
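
A self-contained sketch of the first criterion in _is_quantized above (a dtype that is neither floating point nor complex is treated as integer, hence quantized); illustrative only, not backend code:

import torch

def is_integer_dtype(dtype: torch.dtype) -> bool:
    # Mirrors the dtype criterion: neither floating point nor complex
    # implies an integer, i.e. already-quantized, tensor.
    return not dtype.is_floating_point and not dtype.is_complex

assert is_integer_dtype(torch.int8)
assert not is_integer_dtype(torch.float32)
assert not is_integer_dtype(torch.complex64)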


class CheckProperQuantization(OperatorSupportBase):
"""
    For targeted nodes, check that they have been quantized as expected. In most cases this means that a pair of quantize
@@ -351,7 +422,6 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

vals = node.meta["val"]
tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]

@@ -419,7 +489,6 @@ def is_node_supported(


class CheckFloat64Inputs(OperatorSupportBase):

def __init__(
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
):
@@ -429,7 +498,6 @@ def __init__(
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

for input_node in node.all_input_nodes:
tensor = get_first_fake_tensor(input_node)
if tensor.dtype == torch.float64:
10 changes: 9 additions & 1 deletion backends/arm/quantizer/arm_quantizer_utils.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -15,6 +15,8 @@

from typing import cast

from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo

from torch.fx import Node

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -66,4 +68,10 @@ def mark_node_as_annotated(node: Node) -> None:
"""
if Q_ANNOTATION_KEY not in node.meta:
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
annotation_info = ArmAnnotationInfo(
quantized=True,
)
node.meta[Q_ANNOTATION_KEY]._annotated = True
meta_custom = node.meta.get("custom", {})
meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info
node.meta["custom"] = meta_custom
4 changes: 3 additions & 1 deletion backends/arm/quantizer/quantization_annotator.py
@@ -324,6 +324,7 @@ def _match_pattern(
torch.ops.aten.view.default,
torch.ops.aten.view_as.default,
torch.ops.aten.view_copy.default,
torch.ops.aten._unsafe_view.default,
torch.ops.aten.select.int,
torch.ops.aten.select_copy.int,
torch.ops.aten.slice.Tensor,
@@ -356,6 +357,7 @@ def _match_pattern(
]

_one_to_one_shared_input_or_input_act_qspec = [
torch.ops.aten.alias.default,
torch.ops.aten.clone.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
@@ -588,10 +590,10 @@ def any_or_hardtanh_min_zero(n: Node):
]
quant_properties.quant_output = None
elif node.target in [
torch.ops.aten.scalar_tensor.default,
torch.ops.aten.full.default,
torch.ops.aten.full,
torch.ops.aten.fill_.Scalar,
torch.ops.aten.scalar_tensor.default,
]:
quant_properties.quant_inputs = []
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
6 changes: 1 addition & 5 deletions backends/arm/test/misc/test_int64.py
@@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor):
ConstAdd(torch.int64, 2**40),
(torch.rand(10) - 0.5,),
),
"int64_in+float_const": (
ConstAdd(torch.float32),
(torch.randint(0, 10, (10,)),),
),
"fp32_in+int64_buffer_chain": (
BufferChainAdd(torch.int64),
(torch.rand(2, 5, 3) - 0.5,),
@@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple):
ArmTester(
model,
inputs,
common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
common.get_tosa_compile_spec("TOSA-1.0+FP"),
)
.export()
.to_edge_transform_and_lower()
100 changes: 100 additions & 0 deletions backends/arm/test/misc/test_quant_custom_meta.py
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize


class AddSigmoidMul(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sigmoid = torch.nn.Sigmoid()

def forward(self, x, y):
return self.sigmoid(x + y) * x


def get_selective_quantizer(modules):
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_quantization_config())
for module in modules:
quantizer.set_module_type(module, None)

return Quantize(quantizer, get_symmetric_quantization_config())
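
Note: passing None as the config in set_module_type is what excludes the listed module types from quantization, leaving them in float so the partitioner must handle the quantized-to-float boundaries exercised in the tests below.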


def test_qdq_squeezed_fp_op():
"""Test that a float operation surrounded by quantize-dequantize pairs
is correctly handled by the partitioner and the TOSA backend.
Pattern:
q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
|_____Non-delegated____|
"""
aten_op = "torch.ops.aten.add.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
module = AddSigmoidMul()
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
pipeline = TosaPipelineINT(
module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
)
pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 2,
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
},
)
pipeline.run()


class MulAddSigmoidConv(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sigmoid = torch.nn.Sigmoid()
self.conv = torch.nn.Conv1d(3, 3, 1)

def forward(self, x, y):
return self.conv(self.sigmoid(x + y * x))


def test_quantized_to_float_transition():
"""Test that a model executing quantized ops followed by float ops
is correctly handled by the partitioner and the TOSA backend.
Pattern:
q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
|____Non-delegated___|
"""
aten_op = "torch.ops.aten.add.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
module = MulAddSigmoidConv()
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
pipeline = TosaPipelineINT(
module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
)
pipeline.change_args(
"quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
)
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 1,
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
},
)
pipeline.run()
@@ -37,7 +37,8 @@ class TestSD3Transformer2DModel:

ops_after_partitioner_INT = {
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
"torch.ops.higher_order.executorch_call_delegate": 2,
"torch.ops.higher_order.executorch_call_delegate": 3,
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
}

def _prepare_inputs(
1 change: 0 additions & 1 deletion backends/arm/test/models/test_nn_functional.py
@@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data):
@parametrize(
"test_data",
module_tests,
{"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
)
def test_nn_functional_INT(test_data):
module, inputs = test_data
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_eye.py
@@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t):
input_data(),
EyeAdd.aten_op,
use_to_edge_transform_and_lower=True,
).dump_artifact("to_edge_transform_and_lower")
)
pipeline.pop_stage("check.quant_nodes")
pipeline.run()
