Commit 6e7353f

Arm backend: Add 6D tensor and pixel shuffle/unshuffle support (#14626)
Adds the 6D tensor support required by pixel_shuffle/pixel_unshuffle when given 4D inputs; for now, only 4D inputs to these ops are supported. Adds TOSA and VGF unit tests, plus xfailing Ethos-U85 unit tests. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218
1 parent b6bc421 commit 6e7353f
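For context, a minimal sketch of my own (not code from this commit) showing why a 4D pixel_shuffle input implies rank-6 intermediates: the op is commonly lowered as view → permute → view, and both intermediate tensors are 6D.

import torch

def pixel_shuffle_decomposed(x: torch.Tensor, r: int) -> torch.Tensor:
    """(N, C*r*r, H, W) -> (N, C, H*r, W*r) via rank-6 intermediates."""
    n, c_rr, h, w = x.shape
    c = c_rr // (r * r)
    x = x.view(n, c, r, r, h, w)        # rank 6
    x = x.permute(0, 1, 4, 2, 5, 3)     # (N, C, H, r, W, r), still rank 6
    return x.reshape(n, c, h * r, w * r)

x = torch.randn(2, 4, 3, 5)
assert torch.equal(pixel_shuffle_decomposed(x, 2), torch.pixel_shuffle(x, 2))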

File tree

8 files changed, +310 −47 lines changed
backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 64 additions & 37 deletions

@@ -26,6 +26,9 @@
     NNCHW_ORDER,
     NNHWC_INVERSE_ORDER,
     NNHWC_ORDER,
+    NNNCHW_ORDER,
+    NNNHWC_INVERSE_ORDER,
+    NNNHWC_ORDER,
 )
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -51,12 +54,6 @@ class ToTosaMemoryFormatPass(ExportPass):

     _passes_required_after: Set[Type[ExportPass]] = set()

-    NHWC_order = (0, 2, 3, 1)
-    NHWC_inverse_order = (0, 3, 1, 2)
-    HWCM_order = (2, 3, 0, 1)
-    NNHWC_order = (0, 1, 3, 4, 2)
-    NNHWC_inverse_order = (0, 1, 4, 2, 3)
-
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
@@ -93,7 +90,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
-        if len(shape) >= 5:
+        if len(shape) >= 6:
+            C = shape[3]
+            H = shape[4]
+            W = shape[5]
+        elif len(shape) == 5:
             C = shape[2]
             H = shape[3]
             W = shape[4]
@@ -112,25 +113,26 @@ def memory_format_differs(shape):

     @staticmethod
     def is_channel_reshape(input_shape, output_shape):
-        """Returns true if the reshape changes the channel dimension"""
-        if not (
-            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
-            or (len(input_shape) == 4 and len(output_shape) == 5)
-            or (len(input_shape) == 5 and len(output_shape) == 4)
-        ):
+        """Returns true if reshape changes the channel dimension or batch product dimension(s)"""
+
+        valid_ranks = {4, 5, 6}
+
+        if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
             return False

         C_old = input_shape[-3]
         C_new = output_shape[-3]

-        N_new = (
-            output_shape[0]
-            if len(output_shape) == 4
-            else output_shape[0] * output_shape[1]
-        )
-        N_old = (
-            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
-        )
+        def get_batch_prod_dim(shape):
+            product = 1
+
+            for dim in shape[:-3]:
+                product = product * dim
+
+            return product
+
+        N_old = get_batch_prod_dim(input_shape)
+        N_new = get_batch_prod_dim(output_shape)

         return (N_old != N_new) or (C_old != C_new)

@@ -141,17 +143,27 @@ def insert_input_transpose(node, input_node, graph_module):
             node.replace_input_with(input_node, pre_permute_node)
             return

+        if len(get_first_fake_tensor(input_node).size()) == 6:
+            mem_format = NNNHWC_INVERSE_ORDER
+        elif len(get_first_fake_tensor(input_node).size()) == 5:
+            mem_format = NNHWC_INVERSE_ORDER
+        else:
+            mem_format = NHWC_INVERSE_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(
+            get_first_fake_tensor(input_node).size()
+        )  # or (node) in output path
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
-                    list(
-                        NNHWC_INVERSE_ORDER
-                        if len(get_first_fake_tensor(input_node).size()) == 5
-                        else NHWC_INVERSE_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
@@ -163,26 +175,38 @@ def insert_input_transpose(node, input_node, graph_module):

     @staticmethod
     def insert_output_transpose(node, graph_module):
+
+        if len(get_first_fake_tensor(node).size()) == 6:
+            mem_format = NNNHWC_ORDER
+        elif len(get_first_fake_tensor(node).size()) == 5:
+            mem_format = NNHWC_ORDER
+        else:
+            mem_format = NHWC_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(get_first_fake_tensor(node).size())  # or (node) in output path
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
+
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     node,
-                    list(
-                        NNHWC_ORDER
-                        if len(get_first_fake_tensor(node).size()) == 5
-                        else NHWC_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )

-        permute_node.meta["tosa_dim_order"] = (
-            NNHWC_ORDER
-            if len(get_first_fake_tensor(node).size()) == 5
-            else NHWC_ORDER
-        )
+        rank = len(get_first_fake_tensor(node).size())
+        if rank == 6:
+            permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER
+        elif rank == 5:
+            permute_node.meta["tosa_dim_order"] = NNHWC_ORDER
+        else:
+            permute_node.meta["tosa_dim_order"] = NHWC_ORDER
+
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
         )
@@ -261,7 +285,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         ]
         for input_node in inputs:
             input_dim_order = get_first_fake_tensor(input_node).dim_order()
-            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
                 self.insert_output_transpose(input_node, graph_module)

         # Transpose outputs if they are in (N)NCHW format
@@ -276,6 +300,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
             if output_dim_order in (
                 NCHW_ORDER,
                 NNCHW_ORDER,
+                NNNCHW_ORDER,
             ):
                 self.insert_input_transpose(
                     output_node, output_node_input, graph_module
@@ -313,6 +338,8 @@ def call(self, graph_module: torch.fx.GraphModule):
                 dim_order = HWCM_ORDER
             elif node_data.dim() == 5:
                 dim_order = NNHWC_ORDER
+            elif node_data.dim() == 6:
+                dim_order = NNNHWC_ORDER
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
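To illustrate the reworked is_channel_reshape logic, here is a standalone sketch of my own that mirrors the diff rather than importing the pass: every leading dimension before (C, H, W) is folded into a batch product, so the check generalizes uniformly to ranks 4 through 6.

def is_channel_reshape(input_shape, output_shape):
    """Mirror of the updated check: flags reshapes that change C or the batch product."""
    valid_ranks = {4, 5, 6}
    if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
        return False

    def get_batch_prod_dim(shape):
        product = 1
        for dim in shape[:-3]:  # everything before (C, H, W)
            product *= dim
        return product

    return (
        get_batch_prod_dim(input_shape) != get_batch_prod_dim(output_shape)
        or input_shape[-3] != output_shape[-3]
    )

# pixel_shuffle's first view, (N, C*r*r, H, W) -> (N, C, r, r, H, W),
# changes the dim in the channel position, so it is treated as a channel reshape.
assert is_channel_reshape((1, 4, 2, 3), (1, 1, 2, 2, 2, 3))
# Splitting only the batch dimension keeps C and the batch product unchanged.
assert not is_channel_reshape((2, 4, 2, 3), (1, 2, 4, 2, 3))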

backends/arm/constants.py

Lines changed: 5 additions & 2 deletions

@@ -34,10 +34,13 @@
 NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2)
 NNHWC_ORDER: Final = (0, 1, 3, 4, 2)
 NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3)
+NNNHWC_ORDER: Final = (0, 1, 2, 4, 5, 3)
+NNNHWC_INVERSE_ORDER: Final = (0, 1, 2, 5, 3, 4)

 NCHW_ORDER: Final = (0, 1, 2, 3)
-NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1)
 NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
-NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2)
+NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5)

 HWCM_ORDER: Final = (2, 3, 0, 1)
+
+MAX_RANK: Final = 6
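For intuition, a sketch of my own reusing the constants above: the new rank-6 orders are ordinary permutations, NNNHWC_ORDER moves the channel axis last, NNNHWC_INVERSE_ORDER undoes it, and both satisfy the permutation guard the pass now asserts; MAX_RANK = 6 is the rank ceiling they extend support to.

import torch

NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)
NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4)
MAX_RANK = 6

# A permutation is valid for a rank iff it is exactly 0..rank-1 in some order.
assert sorted(NNNHWC_ORDER) == list(range(MAX_RANK))

x = torch.randn(2, 3, 4, 5, 6, 7)           # contiguous, i.e. logical NNNCHW layout
nhwc = x.permute(NNNHWC_ORDER)              # channel axis moved last
assert nhwc.shape == (2, 3, 4, 6, 7, 5)
assert torch.equal(nhwc.permute(NNNHWC_INVERSE_ORDER), x)  # inverse order round-trips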

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
-from executorch.backends.arm.constants import DQ_OPS, Q_OPS
+from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55CastCheck,
     EthosU55DtypeSupport,
@@ -127,7 +127,7 @@ def tosa_support_factory(
     negative_checks: list[OperatorSupportBase] = [
         CheckInt64InputsAndOutputs(exported_program, reporter),
         CheckFloat64Inputs(exported_program, reporter),
-        RankCheck(reporter, max_rank=5),
+        RankCheck(reporter, max_rank=MAX_RANK),
         *[
             reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
             for check in (additional_checks if additional_checks else [])

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions

@@ -370,6 +370,8 @@ def _match_pattern(
     torch.ops.aten.dropout_.default,
     torch.ops.aten.adaptive_avg_pool2d.default,
     torch.ops.aten.alias_copy.default,
+    torch.ops.aten.pixel_shuffle.default,
+    torch.ops.aten.pixel_unshuffle.default,
 ]
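A quick sanity check of my own, outside the commit: pixel_shuffle/pixel_unshuffle only rearrange values, which is presumably why they can be annotated alongside alias/dropout-style ops here, with output and input sharing quantization parameters.

import torch

x = torch.randn(1, 8, 4, 4)
y = torch.pixel_shuffle(x, 2)  # (1, 8, 4, 4) -> (1, 2, 8, 8)

# Pure data movement: the same values, just rearranged, so one set of
# quantization parameters fits both input and output.
assert torch.equal(torch.sort(x.flatten()).values, torch.sort(y.flatten()).values)
# pixel_unshuffle is the exact inverse.
assert torch.equal(torch.pixel_unshuffle(y, 2), x)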

backends/arm/scripts/parse_test_names.py

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,8 @@
     "_native_batch_norm_legit_no_training.default",
     "_native_batch_norm_legit.no_stats",
     "alias_copy.default",
+    "pixel_shuffle.default",
+    "pixel_unshuffle.default",
 ]
 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 0 additions & 4 deletions

@@ -30,16 +30,12 @@ class TestSD3Transformer2DModel:

     # Adjust nbr below as we increase op support.
     ops_after_partitioner_FP = {
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
         "torch.ops.higher_order.executorch_call_delegate": 1,
     }

     ops_after_partitioner_INT = {
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
         "torch.ops.higher_order.executorch_call_delegate": 2,
     }
