Skip to content

Commit 59b9707

Browse files
Arm backend: Align dim_order ops handling (#14064)

- Removes all use of the _to_copy/clone operator. These ops are lowered to _to_dim_order_copy/_clone_dim_order in the to_edge step when skip_dim_order=False, which we expect.
- Remove all dim_order kwargs in the to_tosa_memory_format_pass since we set our own dim_order.
- Add a pass for storing the initial output dim_order to verify that it has not changed unexpectedly.
- Replace RemoveClonePass with a more general RemoveNoopPass which also handles to_dim_order ops casting to the same dtype. Move it to after the to_tosa_memory_format_pass to tidy up after it.
- Change the delegation of to_dim_order_copy_support to pick up casts of the same dtype, as they are removed as no-ops anyway.

Minor fixes:
- Renames to_copy tests to be picked up by the op name parser.
- Moves cast_int64 pass to after compute_constant_ops_aot to assert the buffer exists in .state_dict before accessing it.
- Replace fatal xfail for test_CLIPTextModelWithProjection with numerical diff xfail

Signed-off-by: Adrian Lundell <[email protected]>
1 parent 396d722 commit 59b9707

21 files changed

+243
-157
lines changed

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .arm_pass import ArmPass # noqa # usort: skip
99
from .add_bias_pass import AddBiasPass # noqa
1010
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
11+
from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
1112
from .broadcast_args_pass import BroadcastArgsPass # noqa
1213
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
1314
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
@@ -82,7 +83,7 @@
8283
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
8384
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
8485
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
85-
from .remove_clone_pass import RemoveClonePass # noqa
86+
from .remove_noop_pass import RemoveNoopPass # noqa
8687
from .replace_scalar_with_tensor_pass import ( # noqa
8788
ReplaceScalarWithTensorArgPassTOSABI,
8889
ReplaceScalarWithTensorArgPassTOSAMI,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm._passes import ArmPass
7+
from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders
8+
from executorch.exir.pass_base import PassResult
9+
10+
11+
class AnnotateOutputDimOrderPass(ArmPass):
12+
"""
13+
Stores the current output dim_orders in the meta dict of the output node. This is used
14+
for verifying that the dim order does not change unexpectedly in later passes.
15+
"""
16+
17+
def call(self, graph_module):
18+
output_node = graph_module.graph.output_node()
19+
output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module)
20+
21+
return PassResult(graph_module, True)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from executorch.backends.arm._passes import (
1212
AddBiasPass,
1313
AnnotateDecomposedMatmulPass,
14+
AnnotateOutputDimOrderPass,
1415
BroadcastArgsPass,
1516
CastBoolToInt8Pass,
1617
CastInt64BuffersToInt32Pass,
@@ -81,7 +82,7 @@
8182
MatchArgDtypePass,
8283
MatchArgRanksPass,
8384
QuantizeOperatorArguments,
84-
RemoveClonePass,
85+
RemoveNoopPass,
8586
ReplaceInfValues,
8687
ReplaceScalarWithTensorArgPassTOSABI,
8788
ReplaceScalarWithTensorArgPassTOSAMI,
@@ -119,6 +120,7 @@ def _transform(self, graph_module: GraphModule):
119120
return self(graph_module).graph_module
120121

121122
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
123+
self.add_pass(AnnotateOutputDimOrderPass())
122124
self.add_pass(FuseQuantizedActivationPass())
123125
self.add_pass(RemoveGetItemPass())
124126
self.add_pass(ConvertSplitToSlicePass())
@@ -152,7 +154,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
152154
self.add_pass(ComputeConstantOpsAOT(exported_program))
153155

154156
self.add_pass(DecomposeGroupedConv())
155-
self.add_pass(RemoveClonePass())
156157
self.add_pass(ConvertExpandCopyToRepeatPass())
157158
self.add_pass(UnsqueezeBeforeRepeatPass())
158159
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
@@ -171,11 +172,13 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
171172
self.add_pass(InsertTableOpsPass(exported_program))
172173
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
173174
self.add_pass(ToTosaMemoryFormatPass(exported_program))
175+
self.add_pass(RemoveNoopPass())
174176
self.add_pass(InsertRescalePass())
175177

176178
return self._transform(exported_program.graph_module)
177179

178180
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
181+
self.add_pass(AnnotateOutputDimOrderPass())
179182
self.add_pass(DecomposeExpm1Pass())
180183
self.add_pass(DecomposeLogitPass())
181184
self.add_pass(DecomposeMaskedFill())
@@ -235,10 +238,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
235238
self.add_pass(ComputeConstantOpsAOT(exported_program))
236239

237240
self.add_pass(DecomposeGroupedConv())
238-
self.add_pass(RemoveClonePass())
239241
self.add_pass(ConvertExpandCopyToRepeatPass())
240242
self.add_pass(UnsqueezeBeforeRepeatPass())
241-
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
242243
self.add_pass(DecomposeSumPass())
243244
self.add_pass(DecomposeCumsumPass(exported_program))
244245
self.add_pass(Conv1dUnsqueezePass())
@@ -249,10 +250,12 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
249250

250251
self.add_pass(FuseViewCopyTransform())
251252
self.add_pass(FuseConstantArgsPass(exported_program))
253+
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
252254
self.add_pass(AddBiasPass(exported_program))
253255
self.add_pass(InsertTableOpsPass(exported_program))
254256
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
255257
self.add_pass(ToTosaMemoryFormatPass(exported_program))
258+
self.add_pass(RemoveNoopPass())
256259
self.add_pass(InsertRescalePass())
257260

258261
return self._transform(exported_program.graph_module)

backends/arm/_passes/arm_pass_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,8 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
235235
node.kwargs = kwargs
236236
else:
237237
raise RuntimeError("Invalid type")
238+
239+
240+
def get_output_dim_orders(graph_module):
241+
output_node = graph_module.graph.output_node()
242+
return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]]

backends/arm/_passes/convert_int64_output_ops_to_int32.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ class ConvertInt64OutputOpsToInt32Pass(ExportPass):
6868

6969
def _get_decomposition(self, op):
7070
if op in self.edge_ops:
71-
return exir_ops.edge.aten._to_copy.default
71+
return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
7272

7373
if op in self.aten_ops:
74-
return torch.ops.aten._to_copy.default
74+
return torch.ops.dim_order_ops._to_dim_order_copy.default
7575

7676
raise RuntimeError(
7777
f"[{self.__class__.__name__}] Can't get decomposition for op {op}"

backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,17 @@ class DecorateFp32toInt32CastingPass(ArmPass):
3030
To lower pytorch fp32 -> int32 casting to TOSA,
3131
we need to transform the value with Ceil, Floor, and Where.
3232
Before:
33-
output = to_copy(x, dtype=torch.int32)
33+
output = to_dim_order_copy(x, dtype=torch.int32)
3434
After:
3535
%zero = full((1,), 0.0, dtype=torch.float32)
3636
is_non_negative = x >= %zero
3737
floor_x = floor(x)
3838
ceil_x = ceil(x)
3939
decorated_x = where(is_non_negative, floor_x, ceil_x)
40-
output = to_copy(decorated_x, dtype=torch.int32)
40+
output = to_dim_order_copy(decorated_x, dtype=torch.int32)
4141
"""
4242

4343
targets = [
44-
exir_ops.edge.aten._to_copy.default,
4544
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
4645
]
4746

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ class InsertCastForOpsWithInt64InputPass(ExportPass):
3131

3232
def get_decomposition(self, op):
3333
if op in self.edge_ops:
34-
return exir_ops.edge.aten._to_copy.default
34+
return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
3535

3636
if op in self.aten_ops:
37-
return torch.ops.aten._to_copy.default
37+
return torch.ops.dim_order_ops._to_dim_order_copy.default
3838

3939
raise RuntimeError(
4040
f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
@@ -56,15 +56,14 @@ def _check_aten_embedding_within_int32(self, weights, indices, node: torch.fx.No
5656
return True
5757

5858
def _insert_int32_cast_before_node(self, graph, node, original_input):
59-
to_copy_op = self.get_decomposition(node.target)
59+
to_dim_order_copy_op = self.get_decomposition(node.target)
6060
with graph.inserting_before(node):
6161
cast_before = create_node(
6262
graph,
63-
to_copy_op,
63+
to_dim_order_copy_op,
6464
args=(original_input,),
6565
kwargs={
6666
"dtype": torch.int32,
67-
"memory_format": torch.preserve_format,
6867
},
6968
)
7069
node.replace_input_with(original_input, cast_before)

backends/arm/_passes/remove_clone_pass.py renamed to backends/arm/_passes/remove_noop_pass.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,20 @@
1414
logger = logging.getLogger(__name__)
1515

1616

17-
class RemoveClonePass(ExportPass):
18-
"""Remove all clones from graph_module"""
17+
class RemoveNoopPass(ExportPass):
18+
"""Remove no-ops from graph_module"""
1919

2020
def call_operator(self, op, args, kwargs, meta):
21-
if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
21+
if op not in (
22+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
23+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
24+
):
2225
return super().call_operator(op, args, kwargs, meta)
2326

24-
if len(args) != 1:
25-
raise ValueError(
26-
f"clone operator expects exactly one argument, got {len(args)}"
27-
)
27+
input_dtype = args[0].data.dtype
28+
output_dtype = kwargs.get("dtype", input_dtype)
2829

29-
if "memory_format" in kwargs:
30-
logger.warning(
31-
f"Removing clone with memory_format '{kwargs['memory_format']}'."
32-
)
30+
if input_dtype != output_dtype:
31+
return super().call_operator(op, args, kwargs, meta)
3332

3433
return args[0]

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,22 @@
66
# pyre-unsafe
77

88

9+
import logging
10+
911
import torch
12+
from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
1013
from executorch.backends.arm._passes.arm_pass_utils import (
1114
create_node,
1215
get_first_fake_tensor,
16+
get_output_dim_orders,
1317
is_param_node,
1418
)
1519
from executorch.exir import ExportedProgram
1620
from executorch.exir.dialects._ops import ops as exir_ops
1721
from executorch.exir.pass_base import ExportPass, PassResult
1822

23+
logger = logging.getLogger(__name__)
24+
1925

2026
def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
2127
"""
@@ -250,10 +256,27 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
250256
node, input_node, graph_module
251257
)
252258

259+
def remove_dim_order_kwargs(
260+
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
261+
):
262+
if node.op != "call_function":
263+
return
264+
265+
kwargs = dict(node.kwargs)
266+
267+
if "dim_order" in kwargs:
268+
logger.warning(
269+
f"Ignoring dim_order kwarg '{kwargs['dim_order']}' for '{node.name}'."
270+
)
271+
del kwargs["dim_order"]
272+
273+
node.kwargs = kwargs
274+
253275
def call(self, graph_module: torch.fx.GraphModule):
254276
for node in graph_module.graph.nodes:
255277
node_data = get_first_fake_tensor(node).data
256278

279+
self.remove_dim_order_kwargs(graph_module, node)
257280
# Inputs and outputs are always in (N)NCHW format
258281
if _is_input(node, self.exported_program) or node.op == "output":
259282
dim_order = tuple(range(node_data.dim()))
@@ -269,10 +292,40 @@ def call(self, graph_module: torch.fx.GraphModule):
269292
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
270293

271294
node.meta["tosa_dim_order"] = dim_order
295+
272296
# Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
273297
# See insert_tosa_transposes for insertion conditions.
274298
self.insert_tosa_transposes(graph_module)
275299
graph_module.recompile()
276300
graph_module = super().call(graph_module).graph_module
277301

278302
return PassResult(graph_module, True)
303+
304+
def requires(self, graph_module) -> None:
305+
"""
306+
This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline.
307+
"""
308+
309+
dim_orders = get_output_dim_orders(graph_module)
310+
original_dim_orders = graph_module.graph.output_node().meta.get(
311+
"original_dim_orders"
312+
)
313+
output_node = graph_module.graph.output_node()
314+
315+
if original_dim_orders is None:
316+
raise RuntimeError(
317+
f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run."
318+
)
319+
320+
if len(dim_orders) != len(original_dim_orders):
321+
raise RuntimeError(
322+
f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run."
323+
)
324+
325+
for node, dim_order, original_dim_order in zip(
326+
output_node.args[0], dim_orders, original_dim_orders
327+
):
328+
if dim_order != original_dim_order:
329+
raise RuntimeError(
330+
f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run."
331+
)

backends/arm/operator_support/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-unsafe
77

88
from . import ( # noqa
9-
clone_support,
9+
clone_dim_order_support,
1010
convolution_support,
1111
embedding_support,
1212
ethos_u55_support,
@@ -18,6 +18,6 @@
1818
right_shift_support,
1919
sin_cos_support,
2020
slice_copy_support,
21-
to_copy_support,
21+
to_dim_order_copy_support,
2222
tosa_supported_operators,
2323
)

0 commit comments

Comments
 (0)