Commit 9c5ef40

Arm backend: Make TOSA backend NCHW-compatible
- Moves the needed input/output transposes into the delegated graph so they run on Ethos-U, rather than requiring the EthosUBackend to implement transposes on CPU.
- Renames annotate_channels_last_dim_order_pass to to_tosa_memory_format_pass to be more descriptive and future-proof.

This change additionally enables running multiple batches, since the Ethos-U transpose supports that natively, whereas the CPU implementation did not.

Change-Id: I676e5915b15cbcc370a03d70bfa2ea2fe20b2210
Signed-off-by: Adrian Lundell <[email protected]>
1 parent ee936b0 commit 9c5ef40

36 files changed: +405 −444 lines
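For context on what the commit changes: interior tensors are given a channels-last ((N)NHWC) dim order for TOSA, while graph inputs and outputs stay in (N)NCHW; the conversion is done by transpose nodes that are now part of the delegated graph and therefore run on Ethos-U. A minimal standalone sketch of the permutations involved (illustrative only; just the NHWC_order/NHWC_inverse_order tuples are taken from the pass shown below):

import torch

# 4D dim orders used by the pass.
NHWC_order = (0, 2, 3, 1)          # (N)CHW -> (N)HWC, applied after a graph input
NHWC_inverse_order = (0, 3, 1, 2)  # (N)HWC -> (N)CHW, applied before a graph output

x_nchw = torch.randn(4, 3, 8, 8)   # batch of 4; multiple batches are handled natively

x_nhwc = x_nchw.permute(NHWC_order)               # input-side transpose
x_roundtrip = x_nhwc.permute(NHWC_inverse_order)  # output-side transpose

assert torch.equal(x_roundtrip, x_nchw)           # the boundary transposes cancel out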

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,6 @@
 from . import arm_pass_utils # noqa
 from .arm_pass import ArmPass # noqa # usort: skip
 from .add_bias_pass import AddBiasPass # noqa
-from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
 from .broadcast_args_pass import BroadcastArgsPass # noqa
 from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
@@ -80,6 +79,7 @@
 )
 from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
 from .size_adjust_input_pass import SizeAdjustInputPass # noqa
+from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
 from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,6 @@
 import executorch.backends.arm.tosa.dialect # noqa: unused
 from executorch.backends.arm._passes import (
     AddBiasPass,
-    AnnotateChannelsLastDimOrder,
     AnnotateDecomposedMatmulPass,
     BroadcastArgsPass,
     CastBoolToInt8Pass,
@@ -79,6 +78,7 @@
     RetraceFoldedDtypesPass,
     ScalarsToAttributePass,
     SizeAdjustInputPass,
+    ToTosaMemoryFormatPass,
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
@@ -156,7 +156,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
 
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
-        self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(InsertRescalePass())
 
         return self._transform(exported_program.graph_module)
@@ -230,7 +230,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(AddBiasPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
-        self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(InsertRescalePass())
 
         return self._transform(exported_program.graph_module)

backends/arm/_passes/decompose_select.py

Lines changed: 8 additions & 4 deletions
@@ -7,7 +7,10 @@
 # pyre-unsafe
 
 import torch
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
@@ -34,8 +37,9 @@ def call(self, graph_module: torch.fx.GraphModule):
 
             input_node, dim, index = node.args
 
-            rank = len(input_node.meta["val"].size())
-            shape = input_node.meta["val"].shape
+            input_tensor = get_first_fake_tensor(input_node)
+            rank = len(input_tensor.size())
+            shape = input_tensor.shape
             dim = dim % rank if dim < 0 else dim
             index = index % shape[dim] if index < 0 else index
 
@@ -44,7 +48,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 graph_module.graph, slice_op, (input_node, dim, index, index + 1)
             )
             squeeze_node = create_node(
-                graph_module.graph, squeeze_op, (slice_node, [dim])
+                graph_module.graph, squeeze_op, (slice_node, [dim]), from_node=node
             )
 
             node.replace_all_uses_with(squeeze_node)
backends/arm/_passes/annotate_channels_last_dim_order_pass.py renamed to backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 67 additions & 23 deletions
@@ -10,8 +10,10 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    is_param_node,
 )
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
+from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.library import impl, Library
@@ -40,7 +42,14 @@ def _transpose_impl(*args, **kwargs):
     return args[0]
 
 
-class AnnotateChannelsLastDimOrder(ExportPass):
+def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
+    """
+    Returns True if the node is an input node, i.e. a placeholder or a parameter.
+    """
+    return node.op == "placeholder" and not is_param_node(exported_program, node)
+
+
+class ToTosaMemoryFormatPass(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
     that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
@@ -54,6 +63,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
     NNHWC_order = (0, 1, 3, 4, 2)
     NNHWC_inverse_order = (0, 1, 4, 2, 3)
 
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        self.exported_program = exported_program
+        super().__init__()
+
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         returns True for w in the following sequence;
@@ -70,6 +83,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
 
         return False
 
+    @staticmethod
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
@@ -116,25 +130,30 @@ def is_channel_reshape(input_shape, output_shape):
 
     @staticmethod
     def insert_input_transpose(node, input_node, graph_module):
+        if input_node.target == torch.ops.passthrough_to_tosa._transpose.default:
+            pre_permute_node = input_node.all_input_nodes[0]
+            node.replace_input_with(input_node, pre_permute_node)
+            return
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose.default,
                 args=(
                     input_node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                        ToTosaMemoryFormatPass.NNHWC_inverse_order
                         if len(get_first_fake_tensor(input_node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                        else ToTosaMemoryFormatPass.NHWC_inverse_order
                     ),
                 ),
+                from_node=node,
             )
             node.replace_input_with(input_node, permute_node)
 
             permute_node.meta["tosa_dim_order"] = tuple(
                 range(len(input_node.meta["val"].size()))
            )
-            permute_node.meta["val"] = input_node.meta["val"]
 
     @staticmethod
     def insert_output_transpose(node, graph_module):
@@ -145,25 +164,23 @@ def insert_output_transpose(node, graph_module):
                 args=(
                     node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_order
+                        ToTosaMemoryFormatPass.NNHWC_order
                         if len(get_first_fake_tensor(node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_order
+                        else ToTosaMemoryFormatPass.NHWC_order
                     ),
                 ),
+                from_node=node,
             )
+
             permute_node.meta["tosa_dim_order"] = (
-                AnnotateChannelsLastDimOrder.NNHWC_order
+                ToTosaMemoryFormatPass.NNHWC_order
                 if len(get_first_fake_tensor(node).size()) == 5
-                else AnnotateChannelsLastDimOrder.NHWC_order
-            )
-            permute_node.meta["val"] = get_first_fake_tensor(node).permute(
-                AnnotateChannelsLastDimOrder.NNHWC_order
-                if len(get_first_fake_tensor(node).size()) == 5
-                else AnnotateChannelsLastDimOrder.NHWC_order
+                else ToTosaMemoryFormatPass.NHWC_order
             )
             node.meta["tosa_dim_order"] = tuple(
                 range(len(get_first_fake_tensor(node).size()))
             )
+
             users = [user for user in node.users if user != permute_node]
             for user in users:
                 user.replace_input_with(node, permute_node)
@@ -174,20 +191,23 @@ def _insert_view_transpose(
     ):
         nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
         nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
-        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+        channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape(
            output_shape, input_shape
         )
 
         if (
             channel_reshape or nhwc_to_nchw
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
-            AnnotateChannelsLastDimOrder.insert_input_transpose(
+        ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape):
+
+            ToTosaMemoryFormatPass.insert_input_transpose(
                 node, input_node, graph_module
             )
+
         if (
             channel_reshape or nchw_to_nhwc
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
-            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+        ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape):
+
+            ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
 
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
@@ -205,9 +225,10 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
             # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder"]:
+            if node.op not in ["call_function", "placeholder", "output"]:
                 continue
 
+            # Transpose views
             elif node.target in (
                 exir_ops.edge.aten.view_copy.default,
                 exir_ops.edge.aten.index.Tensor,
@@ -218,25 +239,48 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
                 output_shape = node.meta["val"].shape
-
                 self._insert_view_transpose(
-                    input_shape, output_shape, node, input_node, graph_module
+                    input_shape,
+                    output_shape,
+                    node,
+                    input_node,
+                    graph_module,
                 )
 
+            # Transpose inputs
+            elif _is_input(node, self.exported_program):
+                input_shape = get_first_fake_tensor(node).size()
+                if len(input_shape) in (4, 5):
+                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+
+            # Transpose outputs
+            elif node.op == "output":
+                output_shape = get_first_fake_tensor(node).size()
+
+                if len(output_shape) in (4, 5):
+                    for input_node in node.all_input_nodes:
+                        ToTosaMemoryFormatPass.insert_input_transpose(
+                            node, input_node, graph_module
+                        )
+
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
            node_data = get_first_fake_tensor(node).data
 
-            if node_data.dim() == 4:
+            # Inputs and outputs are always in (N)NCHW format
+            if _is_input(node, self.exported_program) or node.op == "output":
+                dim_order = tuple(range(node_data.dim()))
+            elif node_data.dim() == 4:
                 dim_order = self.NHWC_order
                 if self.is_weight_node_for_depthwise_conv2d(node):
                     # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                     # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                     dim_order = self.HWCM_order
             elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order  # type: ignore[assignment]
+                dim_order = self.NNHWC_order
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
+
             node.meta["tosa_dim_order"] = dim_order
         # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
         # See insert_tosa_transposes for insertion conditions.
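
To summarize the dim-order assignment in the reworked call() above: graph inputs and outputs keep the identity ((N)NCHW) order, interior 4D/5D tensors get channels-last orders, and depthwise-conv weights get the HWCM order. An illustrative sketch of that decision (the helper name and boolean flags are hypothetical; the order tuples come from the pass):

NHWC_order = (0, 2, 3, 1)
NNHWC_order = (0, 1, 3, 4, 2)
HWCM_order = (2, 3, 0, 1)

def sketch_dim_order(is_graph_input_or_output, ndim, is_depthwise_conv_weight=False):
    # Inputs and outputs stay in (N)NCHW; the inserted passthrough_to_tosa._transpose
    # nodes bridge between the formats inside the delegated graph.
    if is_graph_input_or_output:
        return tuple(range(ndim))
    if ndim == 4:
        return HWCM_order if is_depthwise_conv_weight else NHWC_order
    if ndim == 5:
        return NNHWC_order
    return tuple(range(ndim))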

backends/arm/operators/op_transpose.py

Lines changed: 8 additions & 1 deletion
@@ -47,7 +47,14 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+                ts.DType.FP16,
+            ],
             output.tosa_spec,
         )
 