Skip to content

Commit 81f8d6f

Browse files
Merge branch 'main' into enable-per-channel-quantization-for-VgfPipeline
2 parents 3868638 + 44d24fa commit 81f8d6f

File tree

6 files changed

+180
-81
lines changed

6 files changed

+180
-81
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def call_operator(self, op, args, kwargs, meta):
105105

106106
# Reshape back to 5D if necessary
107107
if len(input_shape) > 4:
108-
original_dims = input_shape[0:-4]
108+
original_dims = input_shape[0:-3]
109109
temp_shape = list(x.data.shape)[1:]
110110
temp_shape = original_dims + temp_shape
111111
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
# pyre-unsafe
77

88
import typing
9+
from typing import cast
910

1011
import torch
1112
import torch.fx as fx
13+
1214
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1315
from executorch.backends.arm._passes.insert_table_ops import TableOps
1416
from executorch.backends.arm.operators.op_permute import transform_permutation_vector
1517
from executorch.backends.arm.tosa_utils import tosa_shape
1618
from executorch.exir.backend.utils import WhyNoPartitionReporter
17-
1819
from executorch.exir.dialects._ops import ops as exir_ops
1920
from torch.fx.passes.operator_support import OperatorSupportBase
2021

@@ -212,19 +213,51 @@ def is_node_supported(
212213
Returns:
213214
False if the operator is not support and True if it is supported.
214215
"""
215-
if not node.target == exir_ops.edge.aten.view_copy.default:
216+
# Select decomposes into squeeze, which in turn becomes a view. Therefore,
217+
# perform the same check on select operators as view operators.
218+
if node.target not in (
219+
exir_ops.edge.aten.view_copy.default,
220+
exir_ops.edge.aten.select.int,
221+
exir_ops.edge.aten.select_copy.int,
222+
):
216223
return True
217224

218-
shape = list(get_first_fake_tensor(node).shape)
225+
if node.target in (
226+
exir_ops.edge.aten.select.int,
227+
exir_ops.edge.aten.select_copy.int,
228+
):
229+
input_node, dim, index = cast(tuple[fx.Node, int, int], node.args)
230+
231+
shape = input_node.meta["val"].shape
232+
rank = len(shape)
233+
if not -rank <= dim < rank:
234+
raise IndexError(
235+
f"Dim {dim} is outside of the range for tensor '{node.target}' of "
236+
f"rank {rank}"
237+
)
238+
dim = dim % rank
239+
240+
size = shape[dim]
241+
if not -size <= index < size:
242+
raise IndexError(
243+
f"Index {index} is outside of the range for dim {dim} with size "
244+
f"{size} for tensor {node.target}"
245+
)
246+
index = index % size
247+
248+
# Shape after squeeze. This may get converted into a view which may become
249+
# a transpose. This is why we're checking select.
250+
squeezed_shape = shape[:dim] + shape[dim + 1 :]
251+
shape = squeezed_shape
252+
else:
253+
shape = list(get_first_fake_tensor(node).shape)
254+
219255
dtype = _try_determine_dtype(node)
220-
permutation = list(typing.cast(list[int], node.args[1]))
221256

222257
rank = len(shape)
223258
if rank > 4:
224259
if dtype == torch.int32:
225-
self.reporter.report_reject(
226-
node, f"No support for {permutation=} in int32."
227-
)
260+
self.reporter.report_reject(node, "No support for rank > 4 in int32.")
228261
return False
229262

230263
if dtype in (torch.int8, torch.int16):

backends/arm/test/ops/test_ceil.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU55PipelineBI,
12+
EthosU85PipelineBI,
13+
TosaPipelineBI,
14+
TosaPipelineMI,
15+
)
16+
17+
input_t1 = Tuple[torch.Tensor]
18+
19+
20+
class Ceil(torch.nn.Module):
21+
def forward(self, x: torch.Tensor):
22+
return torch.ceil(x)
23+
24+
aten_op = "torch.ops.aten.ceil.default"
25+
exir_op = "executorch_exir_dialects_edge__ops_aten_ceil_default"
26+
27+
28+
zeros = torch.zeros(1, 10, 10, 10)
29+
ones = torch.ones(10, 10, 10)
30+
rand = torch.rand(10, 10) - 0.5
31+
randn_pos = torch.randn(1, 4, 4, 4) + 10
32+
randn_neg = torch.randn(1, 4, 4, 4) - 10
33+
ramp = torch.arange(-16, 16, 0.2)
34+
35+
test_data = {
36+
"ceil_zeros": lambda: (Ceil(), zeros),
37+
"ceil_ones": lambda: (Ceil(), ones),
38+
"ceil_rand": lambda: (Ceil(), rand),
39+
"ceil_randn_pos": lambda: (Ceil(), randn_pos),
40+
"ceil_randn_neg": lambda: (Ceil(), randn_neg),
41+
"ceil_ramp": lambda: (Ceil(), ramp),
42+
}
43+
44+
45+
@common.parametrize("test_data", test_data)
46+
def test_ceil_tosa_MI(test_data: input_t1):
47+
module, data = test_data()
48+
pipeline = TosaPipelineMI[input_t1](
49+
module,
50+
(data,),
51+
module.aten_op,
52+
module.exir_op,
53+
)
54+
pipeline.run()
55+
56+
57+
@common.parametrize("test_data", test_data)
58+
def test_ceil_tosa_BI(test_data: input_t1):
59+
module, data = test_data()
60+
pipeline = TosaPipelineBI[input_t1](
61+
module,
62+
(data,),
63+
module.aten_op,
64+
module.exir_op,
65+
atol=0.06,
66+
rtol=0.01,
67+
)
68+
pipeline.run()
69+
70+
71+
@common.parametrize("test_data", test_data)
72+
@common.XfailIfNoCorstone300
73+
def test_ceil_u55_BI(test_data: input_t1):
74+
module, data = test_data()
75+
pipeline = EthosU55PipelineBI[input_t1](
76+
module,
77+
(data,),
78+
module.aten_op,
79+
module.exir_op,
80+
run_on_fvp=True,
81+
)
82+
pipeline.run()
83+
84+
85+
@common.parametrize("test_data", test_data)
86+
@common.XfailIfNoCorstone320
87+
def test_ceil_u85_BI(test_data: input_t1):
88+
module, data = test_data()
89+
pipeline = EthosU85PipelineBI[input_t1](
90+
module,
91+
(data,),
92+
module.aten_op,
93+
module.exir_op,
94+
run_on_fvp=True,
95+
)
96+
pipeline.run()

backends/arm/test/ops/test_unary.py renamed to backends/arm/test/ops/test_floor.py

Lines changed: 19 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,13 @@
1414
TosaPipelineMI,
1515
)
1616

17-
18-
input_t1 = Tuple[torch.Tensor] # Input x
19-
20-
21-
class Ceil(torch.nn.Module):
22-
def forward(self, x: torch.Tensor):
23-
return torch.ceil(x)
24-
25-
op_name = "ceil"
26-
aten_op = "torch.ops.aten.ceil.default"
27-
exir_op = "executorch_exir_dialects_edge__ops_aten_ceil_default"
17+
input_t1 = Tuple[torch.Tensor]
2818

2919

3020
class Floor(torch.nn.Module):
3121
def forward(self, x: torch.Tensor):
3222
return torch.floor(x)
3323

34-
op_name = "floor"
3524
aten_op = "torch.ops.aten.floor.default"
3625
exir_op = "executorch_exir_dialects_edge__ops_aten_floor_default"
3726

@@ -43,77 +32,34 @@ def forward(self, x: torch.Tensor):
4332
randn_neg = torch.randn(1, 4, 4, 4) - 10
4433
ramp = torch.arange(-16, 16, 0.2)
4534

46-
4735
test_data = {
48-
"ceil_zeros": lambda: (
49-
Ceil(),
50-
zeros,
51-
),
52-
"floor_zeros": lambda: (
53-
Floor(),
54-
zeros,
55-
),
56-
"ceil_ones": lambda: (
57-
Ceil(),
58-
ones,
59-
),
60-
"floor_ones": lambda: (
61-
Floor(),
62-
ones,
63-
),
64-
"ceil_rand": lambda: (
65-
Ceil(),
66-
rand,
67-
),
68-
"floor_rand": lambda: (
69-
Floor(),
70-
rand,
71-
),
72-
"ceil_randn_pos": lambda: (
73-
Ceil(),
74-
randn_pos,
75-
),
76-
"floor_randn_pos": lambda: (
77-
Floor(),
78-
randn_pos,
79-
),
80-
"ceil_randn_neg": lambda: (
81-
Ceil(),
82-
randn_neg,
83-
),
84-
"floor_randn_neg": lambda: (
85-
Floor(),
86-
randn_neg,
87-
),
88-
"ceil_ramp": lambda: (
89-
Ceil(),
90-
ramp,
91-
),
92-
"floor_ramp": lambda: (
93-
Floor(),
94-
ramp,
95-
),
36+
"floor_zeros": lambda: (Floor(), zeros),
37+
"floor_ones": lambda: (Floor(), ones),
38+
"floor_rand": lambda: (Floor(), rand),
39+
"floor_randn_pos": lambda: (Floor(), randn_pos),
40+
"floor_randn_neg": lambda: (Floor(), randn_neg),
41+
"floor_ramp": lambda: (Floor(), ramp),
9642
}
9743

9844

9945
@common.parametrize("test_data", test_data)
100-
def test_unary_tosa_MI(test_data: input_t1):
101-
module, test_data = test_data()
46+
def test_floor_tosa_MI(test_data: input_t1):
47+
module, data = test_data()
10248
pipeline = TosaPipelineMI[input_t1](
10349
module,
104-
(test_data,),
50+
(data,),
10551
module.aten_op,
10652
module.exir_op,
10753
)
10854
pipeline.run()
10955

11056

11157
@common.parametrize("test_data", test_data)
112-
def test_unary_tosa_BI(test_data: input_t1):
113-
module, test_data = test_data()
58+
def test_floor_tosa_BI(test_data: input_t1):
59+
module, data = test_data()
11460
pipeline = TosaPipelineBI[input_t1](
11561
module,
116-
(test_data,),
62+
(data,),
11763
module.aten_op,
11864
module.exir_op,
11965
atol=0.06,
@@ -124,11 +70,11 @@ def test_unary_tosa_BI(test_data: input_t1):
12470

12571
@common.parametrize("test_data", test_data)
12672
@common.XfailIfNoCorstone300
127-
def test_unary_u55_BI(test_data: input_t1):
128-
module, test_data = test_data()
73+
def test_floor_u55_BI(test_data: input_t1):
74+
module, data = test_data()
12975
pipeline = EthosU55PipelineBI[input_t1](
13076
module,
131-
(test_data,),
77+
(data,),
13278
module.aten_op,
13379
module.exir_op,
13480
run_on_fvp=True,
@@ -138,11 +84,11 @@ def test_unary_u55_BI(test_data: input_t1):
13884

13985
@common.parametrize("test_data", test_data)
14086
@common.XfailIfNoCorstone320
141-
def test_unary_u85_BI(test_data: input_t1):
142-
module, test_data = test_data()
87+
def test_floor_u85_BI(test_data: input_t1):
88+
module, data = test_data()
14389
pipeline = EthosU85PipelineBI[input_t1](
14490
module,
145-
(test_data,),
91+
(data,),
14692
module.aten_op,
14793
module.exir_op,
14894
run_on_fvp=True,

backends/arm/test/ops/test_mean_dim.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ class MeanDim(torch.nn.Module):
210210
(1, 2),
211211
False,
212212
),
213+
"rank5_2": lambda: (
214+
torch.rand(1, 4, 7, 3, 2),
215+
(2),
216+
False,
217+
),
213218
"u55_avg_pool_not_supported": lambda: (
214219
torch.rand(1, 1, 1, 257),
215220
(0, 1, 2, 3),
@@ -255,6 +260,7 @@ def test_mean_dim_tosa_BI(test_data):
255260
"rank5_01234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
256261
"rank5_234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
257262
"rank5_12": "Rank 5 graph input currently not supported in EthosUBackend",
263+
"rank5_2": "Rank 5 graph input currently not supported in EthosUBackend",
258264
}
259265

260266

backends/arm/test/ops/test_select.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm.test.tester.test_pipeline import (
1414
EthosU55PipelineBI,
1515
EthosU85PipelineBI,
16+
OpNotSupportedPipeline,
1617
TosaPipelineBI,
1718
TosaPipelineMI,
1819
)
@@ -32,6 +33,10 @@
3233
"select3d_0_dim_1_index": lambda: (torch.arange(-16, 16, 0.2), 0, 1),
3334
}
3435

36+
test_data_not_delegated = {
37+
"select3d_large_after_squeeze": lambda: (torch.rand(3, 64, 3, 49, 32), 0, 0),
38+
}
39+
3540
aten_op_copy = "torch.ops.aten.select_copy.int"
3641
aten_op_int = "torch.ops.aten.select.int"
3742

@@ -129,6 +134,19 @@ def test_select_int_u55_BI(test_data: Tuple):
129134
pipeline.run()
130135

131136

137+
@common.parametrize("test_data", test_data_not_delegated)
138+
def test_select_int_u55_BI_not_delegated(test_data: Tuple):
139+
pipeline = OpNotSupportedPipeline[input_t1](
140+
SelectInt(),
141+
test_data(),
142+
{aten_op_copy: 0},
143+
n_expected_delegates=0,
144+
quantize=True,
145+
u55_subset=True,
146+
)
147+
pipeline.run()
148+
149+
132150
@common.parametrize("test_data", test_data_suite, x_fails)
133151
@common.XfailIfNoCorstone320
134152
def test_select_int_u85_BI_copy(test_data: Tuple):

0 commit comments

Comments
 (0)