Skip to content

Commit da97a0e

Browse files
Arm backend: Add aten.select to "view" operator support check (#12672)
Select decomposes into squeeze, which in turn becomes a view. Therefore, perform the same check on select operators as view operators. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 70d9e94 commit da97a0e

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
# pyre-unsafe
77

88
import typing
9+
from typing import cast
910

1011
import torch
1112
import torch.fx as fx
13+
1214
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1315
from executorch.backends.arm._passes.insert_table_ops import TableOps
1416
from executorch.backends.arm.operators.op_permute import transform_permutation_vector
1517
from executorch.backends.arm.tosa_utils import tosa_shape
1618
from executorch.exir.backend.utils import WhyNoPartitionReporter
17-
1819
from executorch.exir.dialects._ops import ops as exir_ops
1920
from torch.fx.passes.operator_support import OperatorSupportBase
2021

@@ -212,19 +213,51 @@ def is_node_supported(
212213
Returns:
213214
False if the operator is not support and True if it is supported.
214215
"""
215-
if not node.target == exir_ops.edge.aten.view_copy.default:
216+
# Select decomposes into squeeze, which in turn becomes a view. Therefore,
217+
# perform the same check on select operators as view operators.
218+
if node.target not in (
219+
exir_ops.edge.aten.view_copy.default,
220+
exir_ops.edge.aten.select.int,
221+
exir_ops.edge.aten.select_copy.int,
222+
):
216223
return True
217224

218-
shape = list(get_first_fake_tensor(node).shape)
225+
if node.target in (
226+
exir_ops.edge.aten.select.int,
227+
exir_ops.edge.aten.select_copy.int,
228+
):
229+
input_node, dim, index = cast(tuple[fx.Node, int, int], node.args)
230+
231+
shape = input_node.meta["val"].shape
232+
rank = len(shape)
233+
if not -rank <= dim < rank:
234+
raise IndexError(
235+
f"Dim {dim} is outside of the range for tensor '{node.target}' of "
236+
f"rank {rank}"
237+
)
238+
dim = dim % rank
239+
240+
size = shape[dim]
241+
if not -size <= index < size:
242+
raise IndexError(
243+
f"Index {index} is outside of the range for dim {dim} with size "
244+
f"{size} for tensor {node.target}"
245+
)
246+
index = index % size
247+
248+
# Shape after squeeze. This may get converted into a view which may become
249+
# a transpose. This is why we're checking select.
250+
squeezed_shape = shape[:dim] + shape[dim + 1 :]
251+
shape = squeezed_shape
252+
else:
253+
shape = list(get_first_fake_tensor(node).shape)
254+
219255
dtype = _try_determine_dtype(node)
220-
permutation = list(typing.cast(list[int], node.args[1]))
221256

222257
rank = len(shape)
223258
if rank > 4:
224259
if dtype == torch.int32:
225-
self.reporter.report_reject(
226-
node, f"No support for {permutation=} in int32."
227-
)
260+
self.reporter.report_reject(node, "No support for rank > 4 in int32.")
228261
return False
229262

230263
if dtype in (torch.int8, torch.int16):

backends/arm/test/ops/test_select.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm.test.tester.test_pipeline import (
1414
EthosU55PipelineBI,
1515
EthosU85PipelineBI,
16+
OpNotSupportedPipeline,
1617
TosaPipelineBI,
1718
TosaPipelineMI,
1819
)
@@ -32,6 +33,10 @@
3233
"select3d_0_dim_1_index": lambda: (torch.arange(-16, 16, 0.2), 0, 1),
3334
}
3435

36+
test_data_not_delegated = {
37+
"select3d_large_after_squeeze": lambda: (torch.rand(3, 64, 3, 49, 32), 0, 0),
38+
}
39+
3540
aten_op_copy = "torch.ops.aten.select_copy.int"
3641
aten_op_int = "torch.ops.aten.select.int"
3742

@@ -129,6 +134,19 @@ def test_select_int_u55_BI(test_data: Tuple):
129134
pipeline.run()
130135

131136

137+
@common.parametrize("test_data", test_data_not_delegated)
138+
def test_select_int_u55_BI_not_delegated(test_data: Tuple):
139+
pipeline = OpNotSupportedPipeline[input_t1](
140+
SelectInt(),
141+
test_data(),
142+
{aten_op_copy: 0},
143+
n_expected_delegates=0,
144+
quantize=True,
145+
u55_subset=True,
146+
)
147+
pipeline.run()
148+
149+
132150
@common.parametrize("test_data", test_data_suite, x_fails)
133151
@common.XfailIfNoCorstone320
134152
def test_select_int_u85_BI_copy(test_data: Tuple):

0 commit comments

Comments
 (0)