
Commit 017cab3

Arm backend: Support channels-last input and output (#14400)
- Insert transposes for input/output iff the incoming/outgoing data is in channels-first format.
- For testing with the TOSA reference model, transpose numpy arrays to and from the correct data format, since numpy has no concept of dim_order.
- Remove checks restricting inputs to channels-first format.
- Remove the check that the dim_order is unchanged before the to_tosa_memory_format pass, since the behaviour of channels-last tensors is not predictable.
- Add dim-order testing of example networks and mv2.
- Add a section about memory formats to the documentation.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Co-authored-by: Adrian Lundell <[email protected]>
1 parent 878b03b commit 017cab3

13 files changed: +296, -181 lines
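
As background for the commit message above: "channels last" versus "channels first" is visible through a tensor's dim_order. A minimal, illustrative snippet (assuming a PyTorch build that exposes torch.Tensor.dim_order(), which the changed passes rely on):

    import torch

    x = torch.randn(1, 2, 3, 4)                     # contiguous NCHW input
    x_cl = x.to(memory_format=torch.channels_last)  # same shape, NHWC in memory

    print(x.dim_order())     # (0, 1, 2, 3): channels first, so the pass inserts a transpose to reach TOSA's NHWC
    print(x_cl.dim_order())  # (0, 2, 3, 1): channels last, already NHWC, no boundary transpose needed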

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 49 additions & 62 deletions
@@ -9,13 +9,23 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
+from executorch.backends.arm._passes.annotate_decomposed_matmul import (
+    AnnotateDecomposedMatmulPass,
+)
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
-    get_output_dim_orders,
     is_param_node,
 )
+from executorch.backends.arm.constants import (
+    HWCM_ORDER,
+    NCHW_ORDER,
+    NHWC_INVERSE_ORDER,
+    NHWC_ORDER,
+    NNCHW_ORDER,
+    NNHWC_INVERSE_ORDER,
+    NNHWC_ORDER,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -38,12 +48,6 @@ class ToTosaMemoryFormatPass(ExportPass):
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
 
-    NHWC_order = (0, 2, 3, 1)
-    NHWC_inverse_order = (0, 3, 1, 2)
-    HWCM_order = (2, 3, 0, 1)
-    NNHWC_order = (0, 1, 3, 4, 2)
-    NNHWC_inverse_order = (0, 1, 4, 2, 3)
-
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
@@ -135,9 +139,9 @@ def insert_input_transpose(node, input_node, graph_module):
             args=(
                 input_node,
                 list(
-                    ToTosaMemoryFormatPass.NNHWC_inverse_order
+                    NNHWC_INVERSE_ORDER
                     if len(get_first_fake_tensor(input_node).size()) == 5
-                    else ToTosaMemoryFormatPass.NHWC_inverse_order
+                    else NHWC_INVERSE_ORDER
                 ),
             ),
             from_node=node,
@@ -157,18 +161,18 @@ def insert_output_transpose(node, graph_module):
             args=(
                 node,
                 list(
-                    ToTosaMemoryFormatPass.NNHWC_order
+                    NNHWC_ORDER
                     if len(get_first_fake_tensor(node).size()) == 5
-                    else ToTosaMemoryFormatPass.NHWC_order
+                    else NHWC_ORDER
                 ),
             ),
             from_node=node,
         )
 
         permute_node.meta["tosa_dim_order"] = (
-            ToTosaMemoryFormatPass.NNHWC_order
+            NNHWC_ORDER
             if len(get_first_fake_tensor(node).size()) == 5
-            else ToTosaMemoryFormatPass.NHWC_order
+            else NHWC_ORDER
         )
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
@@ -218,7 +222,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
             # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder", "output"]:
+            if node.op != "call_function":
                 continue
 
             # Transpose views
@@ -240,21 +244,33 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                     graph_module,
                 )
 
-            # Transpose inputs
-            elif _is_input(node, self.exported_program):
-                input_shape = get_first_fake_tensor(node).size()
-                if len(input_shape) in (4, 5):
-                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+        output_node = graph_module.graph.output_node()
 
-            # Transpose outputs
-            elif node.op == "output":
-                output_shape = get_first_fake_tensor(node).size()
+        # Transpose inputs if they are in (N)NCHW format
+        inputs = [
+            n for n in graph_module.graph.nodes if _is_input(n, self.exported_program)
+        ]
+        for input_node in inputs:
+            input_dim_order = get_first_fake_tensor(input_node).dim_order()
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+                self.insert_output_transpose(input_node, graph_module)
+
+        # Transpose outputs if they are in (N)NCHW format
+        outputs = output_node.args[0]
+        output_dim_orders = output_node.meta.get("original_dim_orders")
+        if output_dim_orders is None:
+            raise RuntimeError(
+                f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}."
+            )
 
-                if len(output_shape) in (4, 5):
-                    for input_node in node.all_input_nodes:
-                        ToTosaMemoryFormatPass.insert_input_transpose(
-                            node, input_node, graph_module
-                        )
+        for output_node_input, output_dim_order in zip(outputs, output_dim_orders):  # type: ignore[arg-type]
+            if output_dim_order in (
+                NCHW_ORDER,
+                NNCHW_ORDER,
+            ):
+                self.insert_input_transpose(
+                    output_node, output_node_input, graph_module
+                )
 
     def remove_dim_order_kwargs(
         self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
@@ -277,17 +293,17 @@ def call(self, graph_module: torch.fx.GraphModule):
             node_data = get_first_fake_tensor(node).data
 
             self.remove_dim_order_kwargs(graph_module, node)
-            # Inputs and outputs are always in (N)NCHW format
+            # Inputs and outputs may vary in dim_order
            if _is_input(node, self.exported_program) or node.op == "output":
-                dim_order = tuple(range(node_data.dim()))
+                dim_order = node_data.dim_order()
             elif node_data.dim() == 4:
-                dim_order = self.NHWC_order
+                dim_order = NHWC_ORDER
                 if self.is_weight_node_for_depthwise_conv2d(node):
                     # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                     # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
-                    dim_order = self.HWCM_order
+                    dim_order = HWCM_ORDER
             elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order
+                dim_order = NNHWC_ORDER
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
 
@@ -300,32 +316,3 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph_module = super().call(graph_module).graph_module
 
         return PassResult(graph_module, True)
-
-    def requires(self, graph_module) -> None:
-        """
-        This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline.
-        """
-
-        dim_orders = get_output_dim_orders(graph_module)
-        original_dim_orders = graph_module.graph.output_node().meta.get(
-            "original_dim_orders"
-        )
-        output_node = graph_module.graph.output_node()
-
-        if original_dim_orders is None:
-            raise RuntimeError(
-                f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run."
-            )
-
-        if len(dim_orders) != len(original_dim_orders):
-            raise RuntimeError(
-                f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run."
-            )
-
-        for node, dim_order, original_dim_order in zip(
-            output_node.args[0], dim_orders, original_dim_orders
-        ):
-            if dim_order != original_dim_order:
-                raise RuntimeError(
-                    f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run."
-                )
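
In short, the pass now keys boundary transposes off each tensor's dim_order rather than assuming contiguous I/O. A rough, hypothetical sketch of that decision (the helper name and argument below are placeholders, not the pass's actual code):

    from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER

    def needs_boundary_transpose(fake_tensor) -> bool:
        # Only (N)NCHW, i.e. contiguous channels-first data, needs a transpose to
        # reach TOSA's (N)NHWC layout; channels-last data passes through unchanged.
        return fake_tensor.dim_order() in (NCHW_ORDER, NNCHW_ORDER)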

backends/arm/constants.py

Lines changed: 12 additions & 0 deletions
@@ -29,3 +29,15 @@
     DEQUANT_PER_TENSOR_OP_T,
 )
 PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP)
+
+NHWC_ORDER: Final = (0, 2, 3, 1)
+NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2)
+NNHWC_ORDER: Final = (0, 1, 3, 4, 2)
+NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3)
+
+NCHW_ORDER: Final = (0, 1, 2, 3)
+NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1)
+NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
+NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2)
+
+HWCM_ORDER: Final = (2, 3, 0, 1)
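
The *_ORDER / *_INVERSE_ORDER pairs are permutation inverses. A standalone sanity check (not part of the commit; the constants are copied here as literals so the snippet runs without executorch):

    import torch

    NHWC_ORDER = (0, 2, 3, 1)          # as defined in backends/arm/constants.py
    NHWC_INVERSE_ORDER = (0, 3, 1, 2)

    x = torch.randn(1, 2, 3, 4)              # NCHW
    nhwc = x.permute(NHWC_ORDER)             # shape becomes (1, 3, 4, 2)
    back = nhwc.permute(NHWC_INVERSE_ORDER)  # back to (1, 2, 3, 4)
    assert torch.equal(back, x)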

backends/arm/operator_support/to_dim_order_copy_support.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ def _merge_supported_types(
         torch.int32,
         torch.bfloat16,
         torch.float16,
+        torch.float32,
     ],
 }
 ALL_SUPPORTED_TYPES = _merge_supported_types(

backends/arm/process_node.py

Lines changed: 0 additions & 7 deletions
@@ -70,13 +70,6 @@ def process_inputs(
     tosa_spec: TosaSpecification,
 ):
     """Serialize an input node"""
-    # inputs need to be in default dim_order (contiguous memory format)
-    meta = node.meta["val"]
-    if meta.dim_order() != tuple(range(meta.dim())):
-        raise RuntimeError(
-            f"Arm backend only supports contiguous memory format for inputs. "
-            f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
-        )
     try:
         tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 0 additions & 9 deletions
@@ -249,15 +249,6 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             handles.inputs->io[i].elem_size);
         return Error::InvalidProgram;
       }
-      supported = executorch::runtime::is_contiguous_dim_order(
-          tensor_in.dim_order().data(), tensor_in.dim());
-      if (!supported) {
-        ET_LOG(
-            Error,
-            "Input %d expected contiguous dim_order, but got non-contiguous dim_order",
-            i);
-        return Error::InvalidProgram;
-      }
 
       // Select a compatible copy routine including checking for input layouts
       // which require permutation.
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+)
+
+
+input_t1 = Tuple[torch.Tensor]  # Input x
+
+
+class ChannelsLastInput(torch.nn.Module):
+    """
+    Test a complex case with (channels last, channels first) input,
+    and (channels first, channels last) output.
+    """
+
+    inputs: input_t1 = (
+        torch.arange(1, 25, dtype=torch.float32)
+        .reshape((1, 2, 3, 4))
+        .to(memory_format=torch.channels_last),
+        torch.arange(1, 25, dtype=torch.float32).reshape((1, 2, 3, 4)),
+    )
+
+    def forward(self, x, y):
+        x = x * x
+        return y, x
+
+
+class ChannelsFirstOutput(torch.nn.Module):
+    """
+    Test converting to channels_first inside the delegate.
+    """
+
+    inputs: input_t1 = (
+        torch.arange(1, 25, dtype=torch.float32)
+        .reshape((1, 2, 3, 4))
+        .to(memory_format=torch.channels_last),
+    )
+
+    def forward(self, x):
+        x = x.clone(memory_format=torch.contiguous_format) * x
+        return x
+
+
+class ChannelsLastOutput(torch.nn.Module):
+    """
+    Test changing of dim_order inside the delegate.
+    """
+
+    inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),)
+
+    def forward(self, x):
+        x = x * x
+        x = x.clone(memory_format=torch.channels_last)
+        return x
+
+
+class ChannelsLastInsidePartition(torch.nn.Module):
+    """
+    Test dim_order changes inside the partition, but no dim_order changes at input/output.
+    """
+
+    inputs: input_t1 = (torch.randn((1, 2, 3, 3)),)
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(3, 3))
+
+    def forward(self, x):
+        return (
+            self.conv2d(x.clone(memory_format=torch.channels_last)).clone(
+                memory_format=torch.contiguous_format
+            )
+            * 1
+        )
+
+
+test_modules = {
+    "channels_last_input": ChannelsLastInput,
+    "channels_first_output": ChannelsFirstOutput,
+    "channels_last_output": ChannelsLastOutput,
+    "channels_last_inside_partition": ChannelsLastInsidePartition,
+}
+
+
+@common.parametrize("module", test_modules)
+def test_dim_order_tosa_FP(module):
+    pipeline = TosaPipelineFP[input_t1](module(), module.inputs, [])
+    pipeline.run()
+
+
+@common.parametrize("module", test_modules)
+def test_dim_order_tosa_INT(module):
+    pipeline = TosaPipelineINT[input_t1](
+        module(), module.inputs, [], symmetric_io_quantization=True
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("module", test_modules)
+def test_dim_order_u55_INT(module):
+    pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, [])
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("module", test_modules)
+def test_dim_order_u85_INT(module):
+    pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, [])
+    pipeline.run()
