
Commit 5d656a8

mansnils authored and pytorchbot committed
Arm backend: Add 6D tensor and pixel shuffle/unshuffle support (#14626)
Adds the 6D tensor support that pixel_shuffle/pixel_unshuffle need when given 4D inputs; for now only 4D inputs to these ops are supported. Adds TOSA, VGF and xfailing Ethos-U85 unit tests.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

(cherry picked from commit 6e7353f)
1 parent 54030da commit 5d656a8
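
For context, a minimal standalone sketch (not part of this commit) of why rank-6 support is what pixel_shuffle needs on a 4D input: the op is equivalent to a rank-6 view followed by a permute and a reshape, so lowering it through the Arm backend requires 6D tensor handling. The shapes and upscale factor below are illustrative.

```python
import torch

# Illustrative only: pixel_shuffle on a 4D (N, C*r*r, H, W) input decomposes into a
# rank-6 view + permute + reshape, which is where the new 6D support comes in.
n, c_out, r, h, w = 1, 3, 2, 4, 4
x = torch.randn(n, c_out * r * r, h, w)

expected = torch.nn.functional.pixel_shuffle(x, upscale_factor=r)

y = (
    x.view(n, c_out, r, r, h, w)   # rank-6 intermediate
    .permute(0, 1, 4, 2, 5, 3)     # (N, C, H, r, W, r)
    .reshape(n, c_out, h * r, w * r)
)
assert torch.equal(y, expected)
```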

File tree: 8 files changed, +312 −41 lines


backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 66 additions & 31 deletions
@@ -25,6 +25,9 @@
     NNCHW_ORDER,
     NNHWC_INVERSE_ORDER,
     NNHWC_ORDER,
+    NNNCHW_ORDER,
+    NNNHWC_INVERSE_ORDER,
+    NNNHWC_ORDER,
 )
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -48,6 +51,8 @@ class ToTosaMemoryFormatPass(ExportPass):
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
 
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
@@ -84,7 +89,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
-        if len(shape) >= 5:
+        if len(shape) >= 6:
+            C = shape[3]
+            H = shape[4]
+            W = shape[5]
+        elif len(shape) == 5:
             C = shape[2]
             H = shape[3]
             W = shape[4]
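
A small standalone helper (assumed shapes, not backend code) showing the pattern the hunk above generalizes: for every rank the pass handles (4, 5, 6), C, H and W are simply the last three dimensions.

```python
# Standalone sketch: C, H, W are the trailing three dims for NCHW (rank 4),
# NNCHW (rank 5) and NNNCHW (rank 6) shapes alike.
def chw(shape):
    assert 4 <= len(shape) <= 6, "the pass only reasons about ranks 4-6"
    return shape[-3], shape[-2], shape[-1]

print(chw((2, 8, 16, 16)))        # (8, 16, 16)
print(chw((2, 2, 8, 16, 16)))     # (8, 16, 16)
print(chw((2, 2, 2, 8, 16, 16)))  # (8, 16, 16)
```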
@@ -103,25 +112,26 @@ def memory_format_differs(shape):
 
     @staticmethod
     def is_channel_reshape(input_shape, output_shape):
-        """Returns true if the reshape changes the channel dimension"""
-        if not (
-            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
-            or (len(input_shape) == 4 and len(output_shape) == 5)
-            or (len(input_shape) == 5 and len(output_shape) == 4)
-        ):
+        """Returns true if reshape changes the channel dimension or batch product dimension(s)"""
+
+        valid_ranks = {4, 5, 6}
+
+        if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
             return False
 
         C_old = input_shape[-3]
         C_new = output_shape[-3]
 
-        N_new = (
-            output_shape[0]
-            if len(output_shape) == 4
-            else output_shape[0] * output_shape[1]
-        )
-        N_old = (
-            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
-        )
+        def get_batch_prod_dim(shape):
+            product = 1
+
+            for dim in shape[:-3]:
+                product = product * dim
+
+            return product
+
+        N_old = get_batch_prod_dim(input_shape)
+        N_new = get_batch_prod_dim(output_shape)
 
         return (N_old != N_new) or (C_old != C_new)
 
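A quick illustration of the new batch-product check above (standalone, with made-up shapes): a reshape that folds two leading batch dims into one keeps both the batch product and the channel dim, so it is not flagged as a channel reshape.

```python
import math

def batch_prod(shape):
    # product of everything in front of the last three dims (C, H, W)
    return math.prod(shape[:-3])

old = (2, 3, 8, 4, 4)  # rank 5: batch dims (2, 3), C = 8
new = (6, 8, 4, 4)     # rank 4: batch dim 6, C = 8
changed = (batch_prod(old) != batch_prod(new)) or (old[-3] != new[-3])
print(changed)  # False: batch product (6) and channels (8) are both preserved
```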

@@ -132,17 +142,27 @@ def insert_input_transpose(node, input_node, graph_module):
             node.replace_input_with(input_node, pre_permute_node)
             return
 
+        if len(get_first_fake_tensor(input_node).size()) == 6:
+            mem_format = NNNHWC_INVERSE_ORDER
+        elif len(get_first_fake_tensor(input_node).size()) == 5:
+            mem_format = NNHWC_INVERSE_ORDER
+        else:
+            mem_format = NHWC_INVERSE_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(
+            get_first_fake_tensor(input_node).size()
+        )  # or (node) in output path
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
-                    list(
-                        NNHWC_INVERSE_ORDER
-                        if len(get_first_fake_tensor(input_node).size()) == 5
-                        else NHWC_INVERSE_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
@@ -154,26 +174,38 @@ def insert_input_transpose(node, input_node, graph_module):
 
     @staticmethod
     def insert_output_transpose(node, graph_module):
+
+        if len(get_first_fake_tensor(node).size()) == 6:
+            mem_format = NNNHWC_ORDER
+        elif len(get_first_fake_tensor(node).size()) == 5:
+            mem_format = NNHWC_ORDER
+        else:
+            mem_format = NHWC_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(get_first_fake_tensor(node).size())  # or (node) in output path
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
+
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     node,
-                    list(
-                        NNHWC_ORDER
-                        if len(get_first_fake_tensor(node).size()) == 5
-                        else NHWC_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
 
-        permute_node.meta["tosa_dim_order"] = (
-            NNHWC_ORDER
-            if len(get_first_fake_tensor(node).size()) == 5
-            else NHWC_ORDER
-        )
+        rank = len(get_first_fake_tensor(node).size())
+        if rank == 6:
+            permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER
+        elif rank == 5:
+            permute_node.meta["tosa_dim_order"] = NNHWC_ORDER
+        else:
+            permute_node.meta["tosa_dim_order"] = NHWC_ORDER
+
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
         )
@@ -252,7 +284,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         ]
         for input_node in inputs:
             input_dim_order = get_first_fake_tensor(input_node).dim_order()
-            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
                 self.insert_output_transpose(input_node, graph_module)
 
         # Transpose outputs if they are in (N)NCHW format
@@ -267,6 +299,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
             if output_dim_order in (
                 NCHW_ORDER,
                 NNCHW_ORDER,
+                NNNCHW_ORDER,
             ):
                 self.insert_input_transpose(
                     output_node, output_node_input, graph_module
@@ -304,6 +337,8 @@ def call(self, graph_module: torch.fx.GraphModule):
                     dim_order = HWCM_ORDER
             elif node_data.dim() == 5:
                 dim_order = NNHWC_ORDER
+            elif node_data.dim() == 6:
+                dim_order = NNNHWC_ORDER
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
 

backends/arm/constants.py

Lines changed: 5 additions & 2 deletions
@@ -34,10 +34,13 @@
 NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2)
 NNHWC_ORDER: Final = (0, 1, 3, 4, 2)
 NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3)
+NNNHWC_ORDER: Final = (0, 1, 2, 4, 5, 3)
+NNNHWC_INVERSE_ORDER: Final = (0, 1, 2, 5, 3, 4)
 
 NCHW_ORDER: Final = (0, 1, 2, 3)
-NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1)
 NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
-NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2)
+NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5)
 
 HWCM_ORDER: Final = (2, 3, 0, 1)
+
+MAX_RANK: Final = 6
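
As a sanity check on the new rank-6 orders (a standalone sketch, not part of the diff), NNNHWC_ORDER and NNNHWC_INVERSE_ORDER should be valid permutations of range(6) that undo each other:

```python
import torch

NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)
NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4)

x = torch.randn(2, 3, 4, 5, 6, 7)  # rank-6, channels-first style layout
back = x.permute(NNNHWC_ORDER).permute(NNNHWC_INVERSE_ORDER)
assert torch.equal(back, x)        # the two orders are mutual inverses

# Same kind of check the pass asserts before inserting a TRANSPOSE:
for order in (NNNHWC_ORDER, NNNHWC_INVERSE_ORDER):
    assert sorted(order) == list(range(6))
```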

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
-from executorch.backends.arm.constants import DQ_OPS, Q_OPS
+from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55DtypeSupport,
     EthosU55NotSupported,
@@ -126,7 +126,7 @@ def tosa_support_factory(
     negative_checks: list[OperatorSupportBase] = [
         CheckInt64InputsAndOutputs(exported_program, reporter),
         CheckFloat64Inputs(exported_program, reporter),
-        RankCheck(reporter, max_rank=5),
+        RankCheck(reporter, max_rank=MAX_RANK),
         *[
             reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
             for check in (additional_checks if additional_checks else [])
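
The effect of bumping the rank limit, as a rough standalone sketch (RankCheck's real implementation lives in the backend and isn't reproduced here): tensors up to rank MAX_RANK = 6 now pass the check, rank 7 and above are still rejected.

```python
import torch

MAX_RANK = 6  # mirrors backends/arm/constants.py

def exceeds_max_rank(t: torch.Tensor) -> bool:
    return t.dim() > MAX_RANK

print(exceeds_max_rank(torch.empty(1, 2, 3, 4, 5, 6)))     # False: rank 6 is now accepted
print(exceeds_max_rank(torch.empty(1, 2, 3, 4, 5, 6, 7)))  # True: rank 7 is still rejected
```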

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
@@ -365,6 +365,8 @@ def _match_pattern(
     torch.ops.aten.dropout_.default,
     torch.ops.aten.adaptive_avg_pool2d.default,
     torch.ops.aten.alias_copy.default,
+    torch.ops.aten.pixel_shuffle.default,
+    torch.ops.aten.pixel_unshuffle.default,
 ]
 
 
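Why the two ops fit in this list (my reading, illustrated with a standalone sketch): pixel_shuffle/pixel_unshuffle only move elements around, so the input's quantization parameters can describe the output as well.

```python
import torch
import torch.nn.functional as F

scale, zp = 0.02, 128.0
x = torch.randn(1, 4, 3, 3)
q = torch.clamp(torch.round(x / scale) + zp, 0, 255)  # fake-quantized values

# Dequantize-then-shuffle equals shuffle-then-dequantize: the op is pure data
# movement, so a single (scale, zero_point) pair covers both sides.
a = F.pixel_shuffle((q - zp) * scale, upscale_factor=2)
b = (F.pixel_shuffle(q, upscale_factor=2) - zp) * scale
assert torch.allclose(a, b)
```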

backends/arm/scripts/parse_test_names.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@
     "_native_batch_norm_legit_no_training.default",
     "_native_batch_norm_legit.no_stats",
     "alias_copy.default",
+    "pixel_shuffle.default",
+    "pixel_unshuffle.default",
 ]
 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
 

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 0 additions & 4 deletions
@@ -24,16 +24,12 @@ class TestSD3Transformer2DModel(unittest.TestCase):
 
     # Adjust nbr below as we increase op support.
     ops_after_partitioner_FP = {
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
         "torch.ops.higher_order.executorch_call_delegate": 1,
     }
 
     ops_after_partitioner_INT = {
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
         "torch.ops.higher_order.executorch_call_delegate": 2,
     }
