Skip to content

Commit c571400

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Sort passes in transform_for_annotation_pipeline
The passes listed in ArmPassManager.transform_for_annotation_pipeline can feel a bit arbitrary because there is no clearly intended structure or pattern being applied there. Restructure the list into clearly labelled blocks to make the code easier to read and maintain. Signed-off-by: Martin Lindström <[email protected]> Change-Id: I8b94e65d5709474303ff1f7b1fc98e181d27c5ae
1 parent 6de1f4e commit c571400

File tree

2 files changed

+43
-28
lines changed

2 files changed

+43
-28
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,14 @@ def transform_to_backend_pipeline(
293293
)
294294

295295
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
296+
# Preprocessing passes
297+
296298
self.add_pass(
297299
RemoveGraphAssertsPass()
298300
) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph
301+
302+
# Transformation passes (pre scalar -> tensor)
303+
299304
self.add_pass(ConvertInt64ConstOpsToInt32Pass())
300305
self.add_pass(ConvertInt64OutputOpsToInt32Pass())
301306
self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
@@ -306,12 +311,18 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
306311
self.add_pass(CastBoolToInt8Pass())
307312
self.add_pass(DecomposeSignPass())
308313
self.add_pass(DecomposeAddmmPass())
309-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
310314
self.add_pass(DecomposeRemainderPass())
311315
self.add_pass(DecomposeFloorDividePass())
312316
self.add_pass(DecomposeDivTensorModePass())
313-
self.add_pass(DecomposeAddSubAlphaPass())
317+
318+
# Scalars -> tensors
319+
320+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
314321
self.add_pass(ScalarsToAttributePass())
322+
323+
# Transformation passes (post scalar removal)
324+
325+
self.add_pass(DecomposeAddSubAlphaPass())
315326
self.add_pass(DecomposeGroupNormPass())
316327
self.add_pass(DecomposeLayerNormPass())
317328
self.add_pass(DecomposeVarPass())
@@ -325,16 +336,16 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
325336
self.add_pass(DecomposeSqrtPass())
326337
self.add_pass(DecomposeSiluPass())
327338
self.add_pass(DecomposeAvgPool2d())
328-
329339
if self.tosa_spec.is_U55_subset:
330340
# Numerically stable softmax uses amax which is not supported on Ethos-U55
331341
self.add_pass(DecomposeSoftmaxUnstablePass())
332342
else:
333343
self.add_pass(DecomposeSoftmaxPass())
334-
335344
self.add_pass(ConvertMinMaxPass())
336-
self.add_pass(ReplaceInfValues())
337345

346+
# Postprocessing passes
347+
348+
self.add_pass(ReplaceInfValues())
338349
if not self.tosa_spec.is_U55_subset:
339350
# Uses where which is not supported on Ethos-U55
340351
self.add_pass(DecomposeMaskedFill())

backends/arm/_passes/decompose_remainder_pass.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Set, Type
6+
from typing import Dict, Set, Type
77

88
import torch
99
from executorch.backends.arm._passes import ArmPass
@@ -17,46 +17,50 @@
1717

1818
Op = OpOverload | EdgeOpOverload
1919

20-
21-
def _get_remainder_decomposition_ops(op: Op) -> tuple[Op, Op, Op]:
22-
"""
23-
Returns the (div_mode_op, mul_op, sub_op) needed to lower the provided
24-
remainder operator. The concrete ops depend on whether the remainder op is
25-
the aten or edge variant.
26-
"""
27-
if op == exir_ops.edge.aten.remainder.Tensor:
28-
return (
29-
exir_ops.edge.aten.div.Tensor_mode,
30-
exir_ops.edge.aten.mul.Tensor,
31-
exir_ops.edge.aten.sub.Tensor,
32-
)
33-
if op == torch.ops.aten.remainder.Tensor:
34-
return (
35-
torch.ops.aten.div.Tensor_mode,
36-
torch.ops.aten.mul.Tensor,
37-
torch.ops.aten.sub.Tensor,
38-
)
39-
raise RuntimeError(f"Can't get remainder decomposition ops for op {op}")
20+
_decomposition_ops: Dict[Op, tuple[Op, Op, Op]] = {
21+
exir_ops.edge.aten.remainder.Scalar: (
22+
exir_ops.edge.aten.div.Tensor_mode,
23+
exir_ops.edge.aten.mul.Scalar,
24+
exir_ops.edge.aten.sub.Tensor,
25+
),
26+
torch.ops.aten.remainder.Tensor: (
27+
torch.ops.aten.div.Tensor_mode,
28+
torch.ops.aten.mul.Tensor,
29+
torch.ops.aten.sub.Tensor,
30+
),
31+
torch.ops.aten.remainder.Scalar: (
32+
torch.ops.aten.div.Tensor_mode,
33+
torch.ops.aten.mul.Scalar,
34+
torch.ops.aten.sub.Tensor,
35+
),
36+
exir_ops.edge.aten.remainder.Tensor: (
37+
exir_ops.edge.aten.div.Tensor_mode,
38+
exir_ops.edge.aten.mul.Tensor,
39+
exir_ops.edge.aten.sub.Tensor,
40+
),
41+
}
4042

4143

4244
class DecomposeRemainderPass(ArmPass):
4345
"""
4446
Decompose the remainder operation into primitive arithmetic:
4547
remainder(x, y) -> x - floor_div(x, y) * y
46-
where floor_div(x, y) == div(x, y, rounding_mode=\"floor\").
48+
where floor_div(x, y) == div(x, y, rounding_mode="floor").
4749
"""
4850

4951
_passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass}
5052

5153
def call_operator(self, op, args, kwargs, meta, updated=False):
5254
supported_ops = (
55+
exir_ops.edge.aten.remainder.Scalar,
5356
exir_ops.edge.aten.remainder.Tensor,
57+
torch.ops.aten.remainder.Scalar,
5458
torch.ops.aten.remainder.Tensor,
5559
)
5660
if op not in supported_ops:
5761
return super().call_operator(op, args, kwargs, meta, updated)
5862

59-
div_op, mul_op, sub_op = _get_remainder_decomposition_ops(op)
63+
div_op, mul_op, sub_op = _decomposition_ops[op]
6064
x, y = args[0], args[1]
6165

6266
floor_div = super().call_operator(

0 commit comments

Comments
 (0)