Commit b1248b6

Arm backend: Add full partition rejections

- Reject partitions that will be lowered to empty subgraphs, i.e. partitions containing only clones/no-op expands.
- Fix a bug, exposed by the partition fix, that caused an infinite recursive search for qparams when mixing shared-qspec ops and arithmetic ops.
- For eye/ones/zeros the graph is not actually empty, since it contains one constant, and this now works. Simply move the previously xfailing tests to MI/BI. (U55/U85 still fail because of missing CPU ops.)
- Reject partitions with unsupported memory formats instead of failing on them.

Signed-off-by: Adrian Lundell <[email protected]>
Change-Id: Iefa034a8e731d70465eb4883602c958f51aca976

1 parent 275adee commit b1248b6

16 files changed: 437 additions, 190 deletions

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 32 additions & 17 deletions
@@ -8,12 +8,43 @@
 import logging
 from typing import cast
 
+import torch
+
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
 logger = logging.getLogger(__name__)
 
 
+def calculate_multiples(args):
+    input_node_or_tensor = args[0]
+
+    if isinstance(input_node_or_tensor, torch.fx.node.Node):
+        input_data = input_node_or_tensor.meta["val"]
+    else:
+        input_data = input_node_or_tensor.data
+
+    input_shape = input_data.shape
+
+    multiples = cast(list[int], args[1])
+    expanded_rank = len(multiples)
+
+    # Expanded shape is 'input_shape' front-padded with ones.
+    padding = expanded_rank - len(input_shape)
+    extended_shape = [
+        input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
+    ]
+
+    # To convert expand arg to repeat arg, non-repeated dims should have
+    # multiples[dim] = 1. Passing -1 to expand arg means
+    # not changing the size of that dimension.
+    multiples = [
+        multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
+        for i in range(expanded_rank)
+    ]
+    return multiples
+
+
 class ConvertExpandCopyToRepeatPass(ExportPass):
     """
     Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
@@ -26,23 +57,7 @@ def call_operator(self, op, args, kwargs, meta):
         if op != self.expand_copy:
             return super().call_operator(op, args, kwargs, meta)
 
-        input_shape = args[0].data.shape
-        multiples = cast(list[int], args[1])
-        expanded_rank = len(multiples)
-
-        # Expanded shape is 'input_shape' front-padded with ones.
-        padding = expanded_rank - len(input_shape)
-        extended_shape = [
-            input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
-        ]
-
-        # To convert expand arg to repeat arg, non-repeated dims should have
-        # multiples[dim] = 1. Passing -1 to expand arg means
-        # not changing the size of that dimension.
-        multiples = [
-            multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
-            for i in range(expanded_rank)
-        ]
+        multiples = calculate_multiples(args)
 
         if all((x == 1 for x in multiples)):
             # All dimensions/repetitions occur only once. Remove node
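
A quick worked example (hypothetical shapes, not taken from the commit) tracing calculate_multiples by hand: only singleton input dimensions keep their expand factor, every other entry collapses to 1, so the rewritten repeat reproduces the expand exactly.

import torch

x = torch.randn(2, 1)
size = [3, 2, 2]  # args[1] of expand_copy: the target size (-1 would mean "keep")
# padding = 3 - 2 = 1, extended_shape = [1, 2, 1], so multiples = [3, 1, 2]
assert torch.equal(x.expand(size), x.repeat(3, 1, 2))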

backends/arm/_passes/remove_clone_pass.py

Lines changed: 10 additions & 0 deletions
@@ -6,9 +6,13 @@
 
 # pyre-unsafe
 
+import logging
+
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
+logger = logging.getLogger(__name__)
+
 
 class RemoveClonePass(ExportPass):
     """Remove all clones from graph_module"""
@@ -21,4 +25,10 @@ def call_operator(self, op, args, kwargs, meta):
             raise ValueError(
                 f"clone operator expects exactly one argument, got {len(args)}"
             )
+
+        if "memory_format" in kwargs:
+            logger.warning(
+                f"Removing clone with memory_format '{kwargs['memory_format']}'."
+            )
+
         return args[0]
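
As a usage note, a minimal sketch (hypothetical module, not from the commit) of the kind of clone this pass removes with a warning; per the dim-order tests further down, the backend handles layout separately, so the memory_format hint carried by the clone is simply dropped.

import torch

class WithClone(torch.nn.Module):
    def forward(self, x):
        # exports to aten.clone with a memory_format kwarg; RemoveClonePass
        # replaces the node with its input and logs the warning above
        return x.clone(memory_format=torch.contiguous_format) + 1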

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 # pyre-unsafe
 
 from . import (  # noqa
+    clone_support,
     convolution_support,
     embedding_support,
     ethos_u55_support,

backends/arm/operator_support/clone_support.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+
+
+@register_tosa_support_check
+class CloneSupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.clone.default]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:
+
+        input_node = node.args[0]
+        if not isinstance(input_node, fx.Node):
+            self.reporter.report_reject(node, "Non tensor clones are not supported")
+            return False
+
+        return True
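
A comment-only note (my reading of the guard, not stated in the commit) on what this check filters out:

# node.args[0] of aten.clone is normally an fx.Node producing a tensor. If a
# graph ever carries a non-Node argument there (e.g. a constant baked into
# the args), the clone cannot be lowered to TOSA, so the partitioner rejects
# it up front with "Non tensor clones are not supported" instead of failing
# later during lowering.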

backends/arm/operator_support/to_copy_support.py

Lines changed: 1 addition & 23 deletions
@@ -113,30 +113,8 @@ def is_node_tosa_supported(
                 f"Output dtype {output_val.dtype} is not supported in "
                 f"{node.target} for input dtype {input_dtype}. "
                 f"Supported output types: "
-                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
+                f"{' '.join(str(t) for t in supported_dtypes[input_dtype])}",
             )
             return False
 
-        # Check memory format (to_copy)
-        if "memory_format" in node.kwargs:
-            if node.kwargs["memory_format"] in (torch.preserve_format,):
-                self.reporter.report_reject(
-                    node,
-                    f"Argument 'memory_format' is not supported for "
-                    f"{node.target} right now.",
-                )
-                return False
-
-        # Check dim_order (to_dim_order_copy)
-        if "dim_order" in node.kwargs:
-            dim_order = node.kwargs["dim_order"]
-            # pyre-ignore[6]
-            if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
-                self.reporter.report_reject(
-                    node,
-                    f"Argument {dim_order=} is not supported for "
-                    f"{node.target} right now.",
-                )
-                return False
-
         return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 0 additions & 1 deletion
@@ -228,7 +228,6 @@ def is_node_supported(
             exir_ops.edge.aten.var.correction,
             exir_ops.edge.aten.var.dim,
             exir_ops.edge.aten.view_copy.default,
-            exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.unsqueeze_copy.default,
             exir_ops.edge.aten.squeeze_copy.dims,
             exir_ops.edge.aten.pow.Tensor_Scalar,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 33 additions & 7 deletions
@@ -415,14 +415,40 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.minimum.default,
         torch.ops.aten.maximum.default,
     ):
-        shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
-        quant_properties.quant_inputs = [
-            _QuantProperty(0, input_act_qspec),
-            _QuantProperty(
-                1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec  # type: ignore[arg-type]
-            ),
-        ]
+
+        same_input = node.args[0] == node.args[1]
+
+        # Handle an edge case leading to an infinite recursion of shared qspecs
+        input_0_has_quant_info = (
+            hasattr(node.args[0], "meta")
+            and "quantization_annotation" in node.args[0].meta  # type: ignore[union-attr]
+        )
+        input_0_shared = input_0_has_quant_info and isinstance(  # type: ignore[union-attr]
+            node.args[0].meta["quantization_annotation"].output_qspec,  # type: ignore[union-attr]
+            SharedQuantizationSpec,  # type: ignore[union-attr]
+        )
+
+        if same_input:
+            shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),  # type: ignore[arg-type]
+                _QuantProperty(1, input_act_qspec),  # type: ignore[arg-type]
+            ]
+        elif input_0_shared:
+            shared_qspec = SharedQuantizationSpec((node.args[1], node))  # type: ignore[arg-type]
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, shared_qspec),  # type: ignore[arg-type]
+                _QuantProperty(1, input_act_qspec),  # type: ignore[arg-type]
+            ]
+        else:
+            shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),  # type: ignore[arg-type]
+                _QuantProperty(1, shared_qspec),  # type: ignore[arg-type]
+            ]
+
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
+
     elif node.target in (torch.ops.aten.where.self,):
         shared_qspec = SharedQuantizationSpec(node.args[1])  # type: ignore[arg-type]
         quant_properties.quant_inputs = [
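
A comment-only sketch (hypothetical graph, my reading of the guarded edge case) of why an input that already carries a SharedQuantizationSpec gets the special branch:

# m = torch.minimum(x, x)    # shared-qspec op: m's output qspec is a
#                            # SharedQuantizationSpec anchored on x
# out = torch.minimum(m, y)  # before the fix, input 1 (y) always shared from
#                            # input 0 (m), whose output qspec is itself shared,
#                            # so resolving qparams could chase shared specs
#                            # without terminating
#
# With the fix, when input 0 already carries a SharedQuantizationSpec, the
# sharing is anchored on input 1 instead, so every chain of shared specs ends
# at a concrete quantization spec.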

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     "unflatten.int",
     "_native_batch_norm_legit_no_training.default",
     "_native_batch_norm_legit.no_stats",
+    "alias_copy.default",
 ]
 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

backends/arm/test/misc/test_dim_order_guards.py

Lines changed: 84 additions & 39 deletions
@@ -6,62 +6,107 @@
 
 from typing import Tuple
 
-import pytest
-
 import torch
-from executorch.backends.arm.test import common
 
 from executorch.backends.arm.test.tester.test_pipeline import (
-    TosaPipelineBI,
+    OpNotSupportedPipeline,
     TosaPipelineMI,
 )
 
 
-input_t1 = Tuple[torch.Tensor]  # Input x
+input_t1 = Tuple[torch.Tensor, torch.Tensor]  # Inputs x, y
+
+
+class ChannelsLastInput(torch.nn.Module):
+    """
+    Test rejection of a partition which has a channels-last input.
+    """
+
+    inputs: input_t1 = (
+        torch.randn(1, 2, 2, 2).to(memory_format=torch.channels_last),
+        torch.randn(1, 2, 2, 2),
+    )
+
+    def forward(self, x, y):
+        x = x * y
+        x = x.to(dtype=torch.int32, memory_format=torch.channels_last)
+        x = x / 2
+        return x, y
+
+
+class ChannelsLastOutput(torch.nn.Module):
+    """
+    Test rejection of a partition which has a channels-last output.
+    """
+
+    inputs: input_t1 = (
+        torch.randn(1, 2, 2, 2),
+        torch.randn(1, 2, 2, 2),
+    )
 
+    def forward(self, x, y):
+        x = x * y
+        x = x.clone(memory_format=torch.channels_last)
+        x = x / 2
+        return x, y
 
-class Conv2D(torch.nn.Module):
-    inputs: dict[str, input_t1] = {
-        "randn": (torch.randn(1, 2, 20, 20),),
-    }
+
+class ChannelsLastInsidePartition(torch.nn.Module):
+    """
+    Test non-rejection of a fully partitioned module which changes memory format
+    inside the partition. The TOSA backend ignores this memory format change, and
+    since the input and output have the expected channels_last memory format, the
+    partition should be accepted.
+    """
+
+    inputs: input_t1 = (
+        torch.randn(1, 2, 2, 2),
+        torch.randn(1, 2, 2, 2),
+    )
 
     def __init__(self):
         super().__init__()
-        self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=(3, 3))
+        self.conv = torch.nn.Conv2d(2, 2, kernel_size=1, bias=False)
 
-    def forward(self, x):
-        return self.conv2d(x.to(memory_format=torch.channels_last))
+    def forward(self, x, y):
+        x = x * y
+        x = x.to(memory_format=torch.channels_last)
+        x = self.conv(x)
+        x = x.clone(memory_format=torch.contiguous_format)
+        return x, y
 
 
-@common.parametrize("test_data", Conv2D.inputs)
-def test_tosa_MI_pipeline(test_data: input_t1):
-    module = Conv2D()
+def test_dim_order_ok():
     pipeline = TosaPipelineMI[input_t1](
-        module,
-        test_data,
-        [],
-        [],
-        use_to_edge_transform_and_lower=False,
+        ChannelsLastInsidePartition(), ChannelsLastInsidePartition.inputs, []
     )
-    pos = pipeline.find_pos("partition")
-    pipeline._stages = pipeline._stages[:pos]
     pipeline.run()
-    with pytest.raises(RuntimeError):
-        pipeline.tester.partition()
-
-
-@common.parametrize("test_data", Conv2D.inputs)
-def test_tosa_BI_pipeline(test_data: input_t1):
-    module = Conv2D()
-    pipeline = TosaPipelineBI[input_t1](
-        module,
-        test_data,
-        [],
-        [],
-        use_to_edge_transform_and_lower=False,
+
+
+def test_channels_last_input():
+    pipeline = OpNotSupportedPipeline[input_t1](
+        ChannelsLastInput(),
+        ChannelsLastInput.inputs,
+        non_delegated_ops={},
+        n_expected_delegates=0,
+    )
+    pipeline.run()
+
+
+def test_channels_last_output():
+    pipeline = OpNotSupportedPipeline[input_t1](
+        ChannelsLastOutput(),
+        ChannelsLastOutput.inputs,
+        non_delegated_ops={},
+        n_expected_delegates=0,
     )
-    pos = pipeline.find_pos("partition")
-    pipeline._stages = pipeline._stages[:pos]
     pipeline.run()
-    with pytest.raises(RuntimeError):
-        pipeline.tester.partition()
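
For readers less familiar with PyTorch memory formats, a small standalone check (an illustration, not part of the commit) of what distinguishes ChannelsLastInput's two inputs:

import torch

x = torch.randn(1, 2, 2, 2).to(memory_format=torch.channels_last)
y = torch.randn(1, 2, 2, 2)

# Same logical shape, different physical layout: x is laid out NHWC, y NCHW.
# The partitioner now rejects partitions whose inputs or outputs carry the
# channels_last layout instead of erroring out during lowering.
assert x.is_contiguous(memory_format=torch.channels_last)
assert y.is_contiguous()  # default torch.contiguous_format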
