Skip to content

Commit de5962d

Browse files
martinlsm and Martin Lindström
authored
Arm backend: Make SupportedTOSAOperatorChecks work for INT+FP (#16072)
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Martin Lindström <[email protected]>
1 parent a2078c6 commit de5962d

File tree

5 files changed

+156
-84
lines changed

5 files changed

+156
-84
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112

113113
from executorch.backends.arm._passes.arm_pass import ArmPass
114114
from executorch.backends.arm.tosa.specification import (
115+
tosa_spec_in_set,
115116
TosaLoweringContext,
116117
TosaSpecification,
117118
)
@@ -308,16 +309,20 @@ def transform_to_backend_pipeline(
308309
self, exported_program: ExportedProgram, graph_module: GraphModule
309310
):
310311
"""Apply passes before transforming program to backend"""
311-
if self.tosa_spec in (
312-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
313-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
312+
313+
if not tosa_spec_in_set(
314+
self.tosa_spec,
315+
{
316+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
317+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
318+
},
314319
):
315-
return self._tosa_pipeline(exported_program, graph_module)
316-
else:
317-
raise NotImplementedError(
318-
f"No pass pipeline implemented for {self.tosa_spec}"
320+
raise RuntimeError(
321+
f"No pass pipeline found for TOSA specification: {self.tosa_spec}"
319322
)
320323

324+
return self._tosa_pipeline(exported_program, graph_module)
325+
321326
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
322327
# Preprocessing passes
323328
self.add_pass(RemoveGraphAssertsPass())

backends/arm/operator_support/slice_copy_support.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ def is_node_tosa_supported(
4141
non-unit step sizes.
4242
4343
"""
44-
if tosa_spec not in self.tosa_specs:
45-
return False
46-
4744
args = node.args
4845
if len(args) == 5 and (step := args[4]) != 1:
4946
logger.warning(f"{node.target} with step size of {step} not supported.")

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,61 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
146146
return checker
147147

148148

149+
def _is_quantized_constant(node: torch.fx.Node) -> bool:
150+
if node.target not in (
151+
exir_ops.edge.aten.full_like.default,
152+
*ComputeConstantOpsAOTPass.targeted_ops,
153+
):
154+
return False
155+
156+
users = tuple(node.users)
157+
if users and all(user.target in Q_OPS for user in users):
158+
# The node feeds directly into only quantized ops.
159+
return True
160+
161+
for user in users:
162+
if user.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
163+
dim_order_dtype = get_first_fake_tensor(user).dtype
164+
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
165+
return False
166+
else:
167+
return False
168+
169+
return len(users) > 0
170+
171+
172+
def is_quantized(node: torch.fx.Node) -> bool:
173+
"""Checks if the node is quantized.
174+
175+
A node is considered quantized if any of the following is true:
176+
- Its output dtype is not floating point or complex => integer
177+
- It is an op that produces a constant that in turn feeds only quantized users
178+
- It has been marked as quantized in the ArmAnnotationInfo custom meta.
179+
180+
Args:
181+
node (torch.fx.Node): The FX node to check.
182+
183+
Returns:
184+
bool: True if the node is quantized, False otherwise.
185+
"""
186+
187+
node_dtype = get_first_fake_tensor(node).dtype
188+
# Integer-like dtype implies the node is already quantized.
189+
if not node_dtype.is_complex and not node_dtype.is_floating_point:
190+
return True
191+
192+
# Nodes introduced during lowering that exclusively feed quantized users.
193+
if _is_quantized_constant(node):
194+
return True
195+
196+
# Finally, fall back to the explicit annotation emitted by Arm passes.
197+
custom_meta = node.meta.get("custom", {})
198+
if ArmAnnotationInfo.CUSTOM_META_KEY in custom_meta:
199+
return custom_meta[ArmAnnotationInfo.CUSTOM_META_KEY]["quantized"]
200+
201+
return False
202+
203+
149204
def get_registered_tosa_support_checks(
150205
tosa_spec: TosaSpecification,
151206
) -> list[Type[SupportedTOSAOperatorCheck]]:
@@ -194,9 +249,11 @@ def tosa_support_factory(
194249
ControlFlowOpSupported(exported_program, tosa_spec, reporter),
195250
]
196251

197-
if tosa_spec.support_integer():
252+
if tosa_spec.support_integer() and tosa_spec.support_float():
253+
positive_checks.append(TOSAProINTFPSupportList())
254+
elif tosa_spec.support_integer():
198255
positive_checks.append(TOSAProINTSupportList())
199-
if tosa_spec.support_float():
256+
elif tosa_spec.support_float():
200257
positive_checks.append(TOSAProFPSupportList())
201258
# TODO: Refactor to use TOSAProSupportLists + negative checks
202259
positive_checks += [
@@ -268,6 +325,27 @@ def is_node_supported(
268325
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
269326

270327

328+
class TOSAProINTFPSupportList(OperatorSupportBase):
329+
"""
330+
TOSA_PRO_INT_FP_SupportList:
331+
Ops supported in INT+FP profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOp.
332+
"""
333+
334+
def is_node_supported(
335+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
336+
) -> bool:
337+
if node.op != "call_function":
338+
return False
339+
340+
# Select list based on whether the node is quantized.
341+
if is_quantized(node) or node.target in (*Q_OPS, *DQ_OPS):
342+
support_list = TOSA_PRO_INT_SupportList
343+
else:
344+
support_list = TOSA_PRO_FP_SupportList
345+
346+
return node.target in support_list
347+
348+
271349
class CheckArmQuantized(OperatorSupportBase):
272350
"""
273351
Check if the node was marked as quantized in the Arm backend.
@@ -278,60 +356,14 @@ class CheckArmQuantized(OperatorSupportBase):
278356
def __init__(self, reporter: WhyNoPartitionReporter):
279357
self.reporter = reporter
280358

281-
def _is_quantized(self, node: torch.fx.Node) -> bool:
282-
"""Checks if the node is quantized.
283-
284-
A node is considered quantized if at least one criteria is met:
285-
- Its dtype is not floating point or complex => integer
286-
- It is one of the special cases where the node has been created in to_edge, e.g.
287-
.Scalar operations that have been promoted .Tensor operations
288-
where the scalar is replaced by a full op.
289-
- It has been marked as quantized in the ArmAnnotationInfo custom meta.
290-
291-
Args:
292-
node (torch.fx.Node): The FX node to check.
293-
294-
Returns:
295-
bool: True if the node is quantized, False otherwise.
296-
"""
297-
node_dtype = get_first_fake_tensor(node).dtype
298-
if not node_dtype.is_complex and not node_dtype.is_floating_point:
299-
return True
300-
if node.target in (
301-
exir_ops.edge.aten.full_like.default,
302-
*ComputeConstantOpsAOTPass.targeted_ops,
303-
):
304-
# Special cases where nodes have been created in to_edge, e.g.
305-
# .Scalar operations that have been promoted .Tensor operations
306-
# where the scalar is replaced by a full op.
307-
if all(user.target in Q_OPS for user in node.users):
308-
return True
309-
for user in node.users:
310-
if (
311-
user.target
312-
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
313-
):
314-
dim_order_dtype = get_first_fake_tensor(user).dtype
315-
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
316-
return False
317-
else:
318-
return False
319-
return True
320-
return (
321-
ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
322-
and ArmAnnotationInfo(
323-
node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY]
324-
).quantized
325-
)
326-
327359
def is_node_supported(
328360
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
329361
) -> bool:
330362

331363
if node.target in (*DQ_OPS, *Q_OPS):
332364
return True
333365

334-
if not self._is_quantized(node):
366+
if not is_quantized(node):
335367
self.reporter.report_reject(
336368
node, "Node was not marked as quantized in the Arm backend."
337369
)

backends/arm/test/misc/test_quant_custom_meta.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import pytest
67
import torch
78
from executorch.backends.arm.quantizer import (
89
get_symmetric_quantization_config,
@@ -31,31 +32,41 @@ def get_selective_quantizer(modules):
3132
return Quantize(quantizer, get_symmetric_quantization_config())
3233

3334

34-
def test_qdq_squeezed_fp_op():
35+
@pytest.mark.parametrize("fp_extension", [True, False])
36+
def test_qdq_squeezed_fp_op(fp_extension: bool):
3537
"""Test that a float operation surrounded by quantize-dequantize pairs
3638
is correctly handled by the partitioner and the TOSA backend.
3739
Pattern:
3840
q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
39-
|_____Non-delegated____|
41+
|_____unquantized_____|
4042
"""
4143
aten_op = "torch.ops.aten.add.Tensor"
4244
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
4345
module = AddSigmoidMul()
4446
x = torch.randn(2, 3, 4)
4547
y = torch.randn(2, 3, 4)
4648
pipeline = TosaPipelineINT(
47-
module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
49+
module=module,
50+
test_data=(x, y),
51+
aten_op=aten_op,
52+
exir_op=exir_op,
53+
tosa_extensions=["FP"] if fp_extension else None,
4854
)
4955
pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
50-
pipeline.change_args(
51-
"check_count.exir",
52-
{
53-
"torch.ops.higher_order.executorch_call_delegate": 2,
54-
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
55-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
56-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
57-
},
58-
)
56+
57+
if not fp_extension:
58+
# In case we don't have the FP extension, the unquantized part of the
59+
# graph should not be delegated to the Arm backend. Modify the op count
60+
# checks to reflect this behavior.
61+
pipeline.change_args(
62+
"check_count.exir",
63+
{
64+
"torch.ops.higher_order.executorch_call_delegate": 2,
65+
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
66+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
67+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
68+
},
69+
)
5970
pipeline.run()
6071

6172

@@ -69,32 +80,41 @@ def forward(self, x, y):
6980
return self.conv(self.sigmoid(x + y * x))
7081

7182

72-
def test_quantized_to_float_transition():
83+
@pytest.mark.parametrize("fp_extension", [True, False])
84+
def test_quantized_to_float_transition(fp_extension: bool):
7385
"""Test that a model executing quantized ops followed by float ops
7486
is correctly handled by the partitioner and the TOSA backend.
7587
Pattern:
7688
q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
77-
|____Non-delegated___|
89+
|___unquantized___|
7890
"""
7991
aten_op = "torch.ops.aten.add.Tensor"
8092
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
8193
module = MulAddSigmoidConv()
8294
x = torch.randn(2, 3, 4)
8395
y = torch.randn(2, 3, 4)
8496
pipeline = TosaPipelineINT(
85-
module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
97+
module=module,
98+
test_data=(x, y),
99+
aten_op=aten_op,
100+
exir_op=exir_op,
101+
tosa_extensions=["FP"] if fp_extension else None,
86102
)
103+
if not fp_extension:
104+
# In case we don't have the FP extension, the unquantized part of the
105+
# graph should not be delegated to the Arm backend. Modify the op count
106+
# checks to reflect this behavior.
107+
pipeline.change_args(
108+
"check_count.exir",
109+
{
110+
"torch.ops.higher_order.executorch_call_delegate": 1,
111+
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
112+
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
113+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
114+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
115+
},
116+
)
87117
pipeline.change_args(
88118
"quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
89119
)
90-
pipeline.change_args(
91-
"check_count.exir",
92-
{
93-
"torch.ops.higher_order.executorch_call_delegate": 1,
94-
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
95-
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
96-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
97-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
98-
},
99-
)
100120
pipeline.run()

backends/arm/tosa/specification.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,21 @@ def get_context_spec() -> TosaSpecification:
386386
return TosaLoweringContext.tosa_spec_var.get()
387387
except LookupError:
388388
raise RuntimeError("Function must be executed within a TosaLoweringContext")
389+
390+
391+
def tosa_spec_in_set(spec: TosaSpecification, specs: Set[TosaSpecification]) -> bool:
392+
"""Check if a specification matches any in a set, considering base specs.
393+
394+
Args:
395+
spec (TosaSpecification): Specification to check.
396+
specs (Set[TosaSpecification]): Set of specifications to match against.
397+
398+
Returns:
399+
bool: True if a match is found, False otherwise.
400+
401+
"""
402+
base_specs = TosaSpecMapping._get_base_specs(spec)
403+
for base in base_specs:
404+
if base in specs:
405+
return True
406+
return False

0 commit comments

Comments
 (0)