
Commit 41ce65c

Arm backend: Do not partition noop subgraphs (#12995)
- Reject partitions that will be lowered to empty subgraphs, i.e. those containing only clones/noop expands.
- For eye/ones/zeros the graph is not really empty since it contains one constant, and this works now. Simply move the previously xfailing tests to MI/BI. (u55/u85 still fail because of missing CPU ops.)

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Signed-off-by: Adrian Lundell <[email protected]>
1 parent: f8a422c

13 files changed: +240 −147 lines
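To illustrate the first point, here is a minimal sketch (a hypothetical module, not part of this commit) of the kind of graph that would previously be delegated as an empty subgraph: the clone is removed by RemoveClonePass, leaving nothing for the backend to compile, which is what caused the Vela failures referenced in the removed xfails.

import torch

class CloneOnly(torch.nn.Module):
    # clone is a no-op for the Arm backend: RemoveClonePass strips it, so
    # lowering this module alone would yield an empty delegated subgraph.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone()

With this change the partitioner rejects such a partition, so the clone stays on the CPU side instead of producing an empty delegate.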

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 32 additions & 17 deletions
@@ -8,12 +8,43 @@
 import logging
 from typing import cast

+import torch
+
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

 logger = logging.getLogger(__name__)


+def calculate_multiples(args):
+    input_node_or_tensor = args[0]
+
+    if isinstance(input_node_or_tensor, torch.fx.node.Node):
+        input_data = input_node_or_tensor.meta["val"]
+    else:
+        input_data = input_node_or_tensor.data
+
+    input_shape = input_data.shape
+
+    multiples = cast(list[int], args[1])
+    expanded_rank = len(multiples)
+
+    # Expanded shape is 'input_shape' front-padded with ones.
+    padding = expanded_rank - len(input_shape)
+    extended_shape = [
+        input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
+    ]
+
+    # To convert expand arg to repeat arg, non-repeated dims should have
+    # multiples[dim] = 1. Passing -1 to expand arg means
+    # not changing the size of that dimension.
+    multiples = [
+        multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
+        for i in range(expanded_rank)
+    ]
+    return multiples
+
+
 class ConvertExpandCopyToRepeatPass(ExportPass):
     """
     Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
@@ -26,23 +57,7 @@ def call_operator(self, op, args, kwargs, meta):
         if op != self.expand_copy:
             return super().call_operator(op, args, kwargs, meta)

-        input_shape = args[0].data.shape
-        multiples = cast(list[int], args[1])
-        expanded_rank = len(multiples)
-
-        # Expanded shape is 'input_shape' front-padded with ones.
-        padding = expanded_rank - len(input_shape)
-        extended_shape = [
-            input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
-        ]
-
-        # To convert expand arg to repeat arg, non-repeated dims should have
-        # multiples[dim] = 1. Passing -1 to expand arg means
-        # not changing the size of that dimension.
-        multiples = [
-            multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
-            for i in range(expanded_rank)
-        ]
+        multiples = calculate_multiples(args)

         if all((x == 1 for x in multiples)):
             # All dimensions/repetitions occur only once. Remove node
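As a quick illustration of the extracted helper (a hypothetical call, not from the diff; the import path is assumed to mirror the file location), expand multiples are turned into repeat multiples where only singleton dimensions may repeat and -1 means keep the size:

import types

import torch

from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
    calculate_multiples,
)

# Stand-in for the first argument; real callers pass an fx.Node or a value
# exposing a .data tensor (this SimpleNamespace is an assumption for the sketch).
fake_input = types.SimpleNamespace(data=torch.rand(1, 4, 1))

print(calculate_multiples((fake_input, [2, -1, 3])))  # -> [2, 1, 3]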

backends/arm/_passes/remove_clone_pass.py

Lines changed: 10 additions & 0 deletions
@@ -6,9 +6,13 @@

 # pyre-unsafe

+import logging
+
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

+logger = logging.getLogger(__name__)
+

 class RemoveClonePass(ExportPass):
     """Remove all clones from graph_module"""
@@ -21,4 +25,10 @@ def call_operator(self, op, args, kwargs, meta):
             raise ValueError(
                 f"clone operator expects exactly one argument, got {len(args)}"
            )
+
+        if "memory_format" in kwargs:
+            logger.warning(
+                f"Removing clone with memory_format '{kwargs['memory_format']}'."
+            )
+
         return args[0]
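For reference, the new warning fires for clones that keep an explicit memory_format kwarg after export, e.g. (a hypothetical snippet, not part of the diff):

import torch

class CloneChannelsLast(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The exported clone carries the memory_format kwarg, so
        # RemoveClonePass logs a warning before dropping the node.
        return x.clone(memory_format=torch.channels_last) + x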

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 # pyre-unsafe

 from . import ( # noqa
+    clone_support,
     convolution_support,
     embedding_support,
     ethos_u55_support,
backends/arm/operator_support/clone_support.py

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+
+
+@register_tosa_support_check
+class CloneSupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.clone.default]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:
+
+        input_node = node.args[0]
+        if not isinstance(input_node, fx.Node):
+            self.reporter.report_reject(node, "Non tensor clones are not supported")
+            return False
+
+        return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 0 additions & 1 deletion
@@ -228,7 +228,6 @@ def is_node_supported(
             exir_ops.edge.aten.var.correction,
             exir_ops.edge.aten.var.dim,
             exir_ops.edge.aten.view_copy.default,
-            exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.unsqueeze_copy.default,
             exir_ops.edge.aten.squeeze_copy.dims,
             exir_ops.edge.aten.pow.Tensor_Scalar,

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     "unflatten.int",
     "_native_batch_norm_legit_no_training.default",
     "_native_batch_norm_legit.no_stats",
+    "alias_copy.default",
 ]
 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

backends/arm/test/ops/test_alias_copy.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,9 @@ def __init__(self):
         super().__init__()

     def forward(self, x: torch.Tensor):
-        return torch.alias_copy(x)
+        return (
+            torch.alias_copy(x) * 1
+        )  # Multiply by one to make sure it is partitioned.


 @common.parametrize("test_data", AliasCopy.test_data)

backends/arm/test/ops/test_clone.py

Lines changed: 69 additions & 53 deletions
@@ -3,13 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-#
-# Tests the clone op which copies the data of the input tensor (possibly with new data format)
-#

 from typing import Tuple

-import pytest
 import torch

 from executorch.backends.arm.test import common
@@ -28,57 +24,82 @@
 input_t = Tuple[torch.Tensor]


-class Clone(torch.nn.Module):
-    """A simple module that clones an input tensor."""
+class CloneFirstArg(torch.nn.Module):
+    def forward(self, x):
+        return x.clone() + x

-    def forward(self, x: torch.Tensor):
-        return x.clone()

+class CloneSecondArg(torch.nn.Module):
+    def forward(self, x):
+        return x * x.clone()
+
+
+class CloneOutput(torch.nn.Module):
+    def forward(self, x):
+        return (x / x).clone()
+
+
+class CloneBothArgs(torch.nn.Module):
+    def forward(self, x):
+        return x.clone() + x.clone()
+
+
+class CloneAfterOtherOp(torch.nn.Module):
+    def forward(self, x):
+        x = x * 2
+        return x.clone() + x
+
+
+class CloneParallelToOtherOp(torch.nn.Module):
+    def forward(self, x):
+        return x * 2 + x.clone()

-test_data_suite = {
-    "ones_1D_10": lambda: (torch.ones(10),),
-    "ones_1D_50": lambda: (torch.ones(50),),
-    "rand_1D_20": lambda: (torch.rand(20),),
-    "rand_2D_10x10": lambda: (torch.rand(10, 10),),
-    "rand_3D_5x5x5": lambda: (torch.rand(5, 5, 5),),
-    "rand_4D_2x3x4x5": lambda: (torch.rand(2, 3, 4, 5),),
-    "large_tensor": lambda: (torch.rand(1000),),
-}

+delegated_clones = {
+    "clone_first_arg": lambda: (CloneFirstArg, (torch.rand(1, 2, 3, 4),)),
+    "clone_second_arg": lambda: (CloneSecondArg, (torch.rand(1, 2, 3, 4),)),
+    "clone_output": lambda: (CloneOutput, (torch.rand(1, 2, 3, 4),)),
+    "clone_both_args": lambda: (CloneBothArgs, (torch.rand(1, 2, 3, 4),)),
+    "clone_after_other_op": lambda: (CloneAfterOtherOp, (torch.rand(1, 2, 3, 4),)),
+    "clone_parallel_to_other_op": lambda: (
+        CloneParallelToOtherOp,
+        (torch.rand(1, 2, 3, 4),),
+    ),
+}

-@common.parametrize("test_data", test_data_suite)
-def test_clone_tosa_FP(test_data: Tuple[torch.Tensor]):

+@common.parametrize("input_data", delegated_clones)
+def test_clone_tosa_FP(input_data):
+    module, input_tensor = input_data()
     pipeline = TosaPipelineFP[input_t](
-        Clone(),
-        test_data(),
-        aten_op,
-        exir_op,
+        module(),
+        input_tensor,
+        [],
     )
-
     pipeline.run()


-@common.parametrize("test_data", test_data_suite)
-def test_clone_tosa_INT(test_data):
+@common.parametrize("input_data", delegated_clones)
+def test_clone_tosa_INT(input_data):
+    module, input_tensor = input_data()
+
     pipeline = TosaPipelineINT[input_t](
-        Clone(),
-        test_data(),
+        module(),
+        input_tensor,
         aten_op,
         exir_op,
     )
     pipeline.run()


-@common.parametrize("test_data", test_data_suite)
+@common.parametrize("input_data", delegated_clones)
 @common.XfailIfNoCorstone300
-@pytest.mark.xfail(
-    reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
-)
-def test_clone_u55_INT(test_data):
+def test_clone_u55_INT(input_data):
+    module, input_tensor = input_data()
+
     pipeline = EthosU55PipelineINT[input_t](
-        Clone(),
-        test_data(),
+        module(),
+        input_tensor,
         aten_op,
         exir_op,
         run_on_fvp=True,
@@ -87,15 +108,14 @@ def test_clone_u55_INT(test_data):
     pipeline.run()


-@common.parametrize("test_data", test_data_suite)
+@common.parametrize("input_data", delegated_clones)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
-)
-def test_clone_u85_INT(test_data):
+def test_clone_u85_INT(input_data):
+    module, input_tensor = input_data()
+
     pipeline = EthosU85PipelineINT[input_t](
-        Clone(),
-        test_data(),
+        module(),
+        input_tensor,
         aten_op,
         exir_op,
         run_on_fvp=True,
@@ -104,27 +124,23 @@ def test_clone_u85_INT(test_data):
     pipeline.run()


-@common.parametrize("test_data", test_data_suite)
+@common.parametrize("test_data", delegated_clones)
 @common.SkipIfNoModelConverter
-@pytest.mark.xfail(
-    reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
-)
 def test_clone_vgf_FP(test_data):
+    module, input_tensor = test_data()
     pipeline = VgfPipeline[input_t](
-        Clone(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP"
+        module(), input_tensor, aten_op, exir_op, tosa_version="TOSA-1.0+FP"
     )
     pipeline.run()


-@common.parametrize("test_data", test_data_suite)
+@common.parametrize("test_data", delegated_clones)
 @common.SkipIfNoModelConverter
-@pytest.mark.xfail(
-    reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
-)
 def test_clone_vgf_INT(test_data):
+    module, input_tensor = test_data()
     pipeline = VgfPipeline[input_t](
-        Clone(),
-        test_data(),
+        module(),
+        input_tensor,
         aten_op,
         exir_op,
         tosa_version="TOSA-1.0+INT",