
Commit ef42e7f

peroscarandersson8218 authored and committed
Arm backend: Propagate node info from quantizer to backend

Use the Node meta 'custom' field to propagate information from the quantizer to
the partitioner using a new ArmAnnotationInfo data class. This allows us to
track quantized nodes reliably, which is useful for deciding which nodes should
'fold' their quantization parameters and which should be kept in fp when mixing
integer and float in a sub-graph.

Co-authored-by: Per Åstrand <[email protected]>
Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I31309d65cac50e497318eae8678880684ec77cda
1 parent 4ea9ddf commit ef42e7f

File tree: 10 files changed, +250 -21 lines
backends/arm/common/annotation_meta.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class ArmAnnotationInfo:
+    """
+    Data class to carry Arm-specific annotation information through the pipeline.
+    This is intended to be attached to node.meta['custom'] and propagated
+    through partitioning and backend stages. As it's propagated through the pipeline,
+    it's intentionally minimal and only carries whether the node is quantized or not.
+    """
+
+    quantized: bool
+    CUSTOM_META_KEY: str = "_arm_annotation_info"
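
To make the round trip concrete, here is a minimal sketch (not part of the commit; the TinyAdd module and the node selection are illustrative only) of attaching the annotation the way the quantizer's mark_node_as_annotated does, and reading it back the way the partitioner's CheckArmQuantized does:

import torch
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo


class TinyAdd(torch.nn.Module):
    # Hypothetical toy module, used only to obtain an FX node to annotate.
    def forward(self, x):
        return x + x


gm = torch.fx.symbolic_trace(TinyAdd())
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# Attach side, as mark_node_as_annotated does in the quantizer:
meta_custom = add_node.meta.get("custom", {})
meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = ArmAnnotationInfo(quantized=True)
add_node.meta["custom"] = meta_custom

# Read side, as CheckArmQuantized._is_quantized does in the partitioner:
custom = add_node.meta.get("custom", {})
assert (
    ArmAnnotationInfo.CUSTOM_META_KEY in custom
    and custom[ArmAnnotationInfo.CUSTOM_META_KEY].quantized
)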

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 73 additions & 5 deletions
@@ -18,6 +18,7 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55CastCheck,
@@ -134,6 +135,7 @@ def tosa_support_factory(
     ]
 
     if not tosa_spec.support_float():
+        negative_checks.append(CheckArmQuantized(reporter))
         negative_checks.append(CheckProperQuantization(reporter))
         if tosa_spec.is_U55_subset:
             negative_checks.append(EthosU55NotSupported(reporter))
@@ -161,7 +163,6 @@ class TOSAProINTSupportList(OperatorSupportBase):
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-
         return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList
 
 
@@ -174,10 +175,80 @@ class TOSAProFPSupportList(OperatorSupportBase):
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-
         return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
 
 
+class CheckArmQuantized(OperatorSupportBase):
+    """
+    Check if the node was marked as quantized in the Arm backend.
+    This is used to ensure that nodes that were quantized in the Arm backend
+    are only partitioned if they are supported by the TOSA backend.
+    """
+
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        self.reporter = reporter
+
+    def _is_quantized(self, node: torch.fx.Node) -> bool:
+        """Checks if the node is quantized.
+
+        A node is considered quantized if at least one criterion is met:
+        - Its dtype is not floating point or complex => integer.
+        - It is one of the special cases where the node has been created in
+          to_edge, e.g. .Scalar operations that have been promoted to .Tensor
+          operations where the scalar is replaced by a full op.
+        - It has been marked as quantized in the ArmAnnotationInfo custom meta.
+
+        Args:
+            node (torch.fx.Node): The FX node to check.
+
+        Returns:
+            bool: True if the node is quantized, False otherwise.
+        """
+        node_dtype = get_first_fake_tensor(node).dtype
+        if not node_dtype.is_complex and not node_dtype.is_floating_point:
+            return True
+        if node.target in (
+            exir_ops.edge.aten.full_like.default,
+            *ComputeConstantOpsAOT.targeted_ops,
+        ):
+            # Special cases where nodes have been created in to_edge, e.g.
+            # .Scalar operations that have been promoted to .Tensor operations
+            # where the scalar is replaced by a full op.
+            if all(user.target in Q_OPS for user in node.users):
+                return True
+            for user in node.users:
+                if (
+                    user.target
+                    == exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+                ):
+                    dim_order_dtype = get_first_fake_tensor(user).dtype
+                    if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
+                        return False
+                else:
+                    return False
+            return True
+        return (
+            ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
+            and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized
+        )
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        if node.op != "call_function":
+            return False
+
+        if node.target in (*DQ_OPS, *Q_OPS):
+            return True
+
+        if not self._is_quantized(node):
+            self.reporter.report_reject(
+                node, "Node was not marked as quantized in the Arm backend."
+            )
+            return False
+        return True
+
+
 class CheckProperQuantization(OperatorSupportBase):
     """
     For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize
@@ -350,7 +421,6 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-
         vals = node.meta["val"]
         tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]
 
@@ -416,7 +486,6 @@ def is_node_supported(
 
 
 class CheckFloat64Inputs(OperatorSupportBase):
-
     def __init__(
         self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
     ):
@@ -426,7 +495,6 @@ def __init__(
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-
         for input_node in node.all_input_nodes:
             tensor = get_first_fake_tensor(input_node)
             if tensor.dtype == torch.float64:
backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -14,6 +14,8 @@
 
 from typing import cast
 
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
+
 from torch.fx import Node
 
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -65,4 +67,10 @@ def mark_node_as_annotated(node: Node) -> None:
     """
     if Q_ANNOTATION_KEY not in node.meta:
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
+    annotation_info = ArmAnnotationInfo(
+        quantized=True,
+    )
     node.meta[Q_ANNOTATION_KEY]._annotated = True
+    meta_custom = node.meta.get("custom", {})
+    meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info
+    node.meta["custom"] = meta_custom

backends/arm/quantizer/quantization_annotator.py

Lines changed: 3 additions & 1 deletion
@@ -394,6 +394,7 @@ def _match_pattern(
     torch.ops.aten.view.default,
     torch.ops.aten.view_as.default,
     torch.ops.aten.view_copy.default,
+    torch.ops.aten._unsafe_view.default,
     torch.ops.aten.select.int,
     torch.ops.aten.select_copy.int,
     torch.ops.aten.slice.Tensor,
@@ -426,6 +427,7 @@ def _match_pattern(
 ]
 
 _one_to_one_shared_input_or_input_act_qspec = [
+    torch.ops.aten.alias.default,
     torch.ops.aten.clone.default,
     torch.ops.aten.hardtanh.default,
     torch.ops.aten.hardtanh_.default,
@@ -693,10 +695,10 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = None
     elif node.target in [
-        torch.ops.aten.scalar_tensor.default,
         torch.ops.aten.full.default,
         torch.ops.aten.full,
         torch.ops.aten.fill_.Scalar,
+        torch.ops.aten.scalar_tensor.default,
     ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)

backends/arm/test/misc/test_int64.py

Lines changed: 1 addition & 5 deletions
@@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor):
         ConstAdd(torch.int64, 2**40),
         (torch.rand(10) - 0.5,),
     ),
-    "int64_in+float_const": (
-        ConstAdd(torch.float32),
-        (torch.randint(0, 10, (10,)),),
-    ),
     "fp32_in+int64_buffer_chain": (
         BufferChainAdd(torch.int64),
         (torch.rand(2, 5, 3) - 0.5,),
@@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple):
         ArmTester(
             model,
             inputs,
-            common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
+            common.get_tosa_compile_spec("TOSA-1.0+FP"),
         )
         .export()
         .to_edge_transform_and_lower()
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
+
+
+class AddSigmoidMul(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x, y):
+        return self.sigmoid(x + y) * x
+
+
+def get_selective_quantizer(modules):
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
+    quantizer.set_global(get_symmetric_quantization_config())
+    for module in modules:
+        quantizer.set_module_type(module, None)
+
+    return Quantize(quantizer, get_symmetric_quantization_config())
+
+
+def test_qdq_squeezed_fp_op():
+    """Test that a float operation surrounded by quantize-dequantize pairs
+    is correctly handled by the partitioner and the TOSA backend.
+    Pattern:
+        q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
+                          |_____Non-delegated____|
+    """
+    aten_op = "torch.ops.aten.add.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
+    module = AddSigmoidMul()
+    x = torch.randn(2, 3, 4)
+    y = torch.randn(2, 3, 4)
+    pipeline = TosaPipelineINT(
+        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
+    )
+    pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 2,
+            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        },
+    )
+    pipeline.run()
+
+
+class MulAddSigmoidConv(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.conv = torch.nn.Conv1d(3, 3, 1)
+
+    def forward(self, x, y):
+        return self.conv(self.sigmoid(x + y * x))
+
+
+def test_quantized_to_float_transition():
+    """Test that a model executing quantized ops followed by float ops
+    is correctly handled by the partitioner and the TOSA backend.
+    Pattern:
+        q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
+                                            |____Non-delegated___|
+    """
+    aten_op = "torch.ops.aten.add.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
+    module = MulAddSigmoidConv()
+    x = torch.randn(2, 3, 4)
+    y = torch.randn(2, 3, 4)
+    pipeline = TosaPipelineINT(
+        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
+    )
+    pipeline.change_args(
+        "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
+    )
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 1,
+            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
+            "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        },
+    )
+    pipeline.run()

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@ class TestSD3Transformer2DModel:
 
     ops_after_partitioner_INT = {
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
-        "torch.ops.higher_order.executorch_call_delegate": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
     }
 
     def _prepare_inputs(

backends/arm/test/models/test_nn_functional.py

Lines changed: 0 additions & 1 deletion
@@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data):
 @parametrize(
     "test_data",
     module_tests,
-    {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
 )
 def test_nn_functional_INT(test_data):
     module, inputs = test_data

backends/arm/test/ops/test_eye.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t):
         input_data(),
         EyeAdd.aten_op,
         use_to_edge_transform_and_lower=True,
-    ).dump_artifact("to_edge_transform_and_lower")
+    )
     pipeline.pop_stage("check.quant_nodes")
     pipeline.run()