Skip to content

Commit 9602b2e

Browse files
authored
Arm backend: Add chronological dependencies for passes (#14578)
The passes in the Arm backend have an attribute called `_passes_required_after` which is a set specifying which passes must run after the pass itself. This patch sets these dependencies for all the passes. Signed-off-by: Martin Lindstroem <[email protected]>
1 parent d0f486a commit 9602b2e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+363
-87
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
import torch
1313
from executorch.backends.arm._passes.arm_pass_utils import create_node
14+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
15+
FoldAndAnnotateQParamsPass,
16+
)
1417

1518
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1619
from executorch.exir.dialects._ops import ops as exir_ops
@@ -29,7 +32,7 @@ class AnnotateDecomposedMatmulPass(ExportPass):
2932
matmul-op (can be mm or bmm).
3033
"""
3134

32-
_passes_required_after: Set[Type[ExportPass]] = set()
35+
_passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass}
3336

3437
def _match_partition_to_node(
3538
self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
from typing import Set, Type
1010

11+
from executorch.backends.arm._passes.add_bias_pass import AddBiasPass
12+
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
13+
1114
from executorch.exir.dialects._ops import ops as exir_ops
1215
from executorch.exir.pass_base import ExportPass
1316

@@ -23,7 +26,7 @@ class Conv1dUnsqueezePass(ExportPass):
2326
3) squeeze the output back down to 3d.
2427
"""
2528

26-
_passes_required_after: Set[Type[ExportPass]] = set()
29+
_passes_required_after: Set[Type[ExportPass]] = {AddBiasPass, SizeAdjustInputPass}
2730

2831
def call_operator(self, op, args, kwargs, meta):
2932
if op != exir_ops.edge.aten.convolution.default:

backends/arm/_passes/convert_any_default_dim_dims_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from typing import Set, Type
77

88
import torch
9+
from executorch.backends.arm._passes.convert_squeezes_to_view import (
10+
ConvertSqueezesToViewPass,
11+
)
912
from executorch.exir.dialects._ops import ( # type: ignore[import-not-found]
1013
ops as exir_ops,
1114
)
@@ -46,7 +49,7 @@ class ConvertAnyDefaultDimDimsPass(ExportPass):
4649
squeeze(dim = [dim1, dim2])
4750
"""
4851

49-
_passes_required_after: Set[Type[ExportPass]] = set()
52+
_passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass}
5053

5154
def call(self, graph_module: torch.fx.GraphModule):
5255
modified = False

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
import torch
1212

13+
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
14+
UnsqueezeBeforeRepeatPass,
15+
)
1316
from executorch.exir.dialects._ops import ops as exir_ops
1417
from executorch.exir.pass_base import ExportPass
1518

@@ -50,7 +53,7 @@ class ConvertExpandCopyToRepeatPass(ExportPass):
5053
Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
5154
"""
5255

53-
_passes_required_after: Set[Type[ExportPass]] = set()
56+
_passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass}
5457

5558
expand_copy = exir_ops.edge.aten.expand_copy.default
5659
repeat = exir_ops.edge.aten.repeat.default

backends/arm/_passes/convert_full_like_to_full_pass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66
from typing import Set, Type
77

8+
from executorch.backends.arm._passes.arm_pass import ArmPass
9+
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
10+
811
from executorch.exir.dialects._ops import ops as exir_ops
912
from executorch.exir.pass_base import ExportPass
1013

1114

12-
class ConvertFullLikeToFullPass(ExportPass):
15+
class ConvertFullLikeToFullPass(ArmPass):
1316
"""As per the full_like pytorch documentation,
1417
`torch.full_like(input, fill_value)` is equivalent to
1518
`torch.full(input.size(),
@@ -21,7 +24,7 @@ class ConvertFullLikeToFullPass(ExportPass):
2124
Skip layout and device since it's not relevant for our backend.
2225
"""
2326

24-
_passes_required_after: Set[Type[ExportPass]] = set()
27+
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
2528

2629
def call_operator(self, op, args, kwargs, meta):
2730
if op not in [

backends/arm/_passes/convert_int64_const_ops_to_int32.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ConvertInt64ConstOpsToInt32Pass(ExportPass):
3131
5. `torch.tensor`
3232
"""
3333

34-
_passes_required_after: Set[Type[ExportPass]] = set()
34+
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
3535

3636
torch_ops = [
3737
torch.ops.aten.full.default,

backends/arm/_passes/convert_minmax_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from typing import Set, Type
77

88
import torch
9+
from executorch.backends.arm._passes.convert_squeezes_to_view import (
10+
ConvertSqueezesToViewPass,
11+
)
912
from executorch.exir.dialects._ops import ops as exir_ops
1013
from executorch.exir.pass_base import ExportPass, PassResult
1114

@@ -31,7 +34,7 @@ class ConvertMinMaxPass(ExportPass):
3134
squeeze(dim = [dim1, dim2])
3235
"""
3336

34-
_passes_required_after: Set[Type[ExportPass]] = set()
37+
_passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass}
3538

3639
def check_argmax(self, node):
3740
"""

backends/arm/_passes/convert_squeezes_to_view.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from typing import Set, Type
1010

11+
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
12+
1113
from executorch.exir.dialects._ops import ops as exir_ops
1214
from executorch.exir.pass_base import ExportPass
1315

@@ -17,7 +19,7 @@ class ConvertSqueezesToViewPass(ExportPass):
1719
Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors.
1820
"""
1921

20-
_passes_required_after: Set[Type[ExportPass]] = set()
22+
_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform}
2123

2224
def call_operator(self, op, args, kwargs, meta):
2325
if op not in [

backends/arm/_passes/convert_to_clamp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
from typing import Set, Tuple, Type
77

8+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
9+
QuantizeOperatorArguments,
10+
)
11+
812
from executorch.exir.dialects._ops import ops as exir_ops
913
from executorch.exir.pass_base import ExportPass
1014

@@ -24,7 +28,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:
2428

2529

2630
class ConvertToClampPass(ExportPass):
27-
_passes_required_after: Set[Type[ExportPass]] = set()
31+
_passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments}
2832

2933
def call_operator(self, op, args, kwargs, meta):
3034
if op not in edge_operators:

backends/arm/_passes/decompose_acosh_pass.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
from typing import Set, Type
99

1010
from executorch.backends.arm._passes import ArmPass
11+
from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass
12+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass # noqa
13+
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
14+
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
15+
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
16+
ReplaceScalarWithTensorArgPassTOSAMI,
17+
)
1118
from executorch.exir.dialects._ops import ops as exir_ops
1219
from executorch.exir.pass_base import ExportPass
1320

@@ -22,7 +29,13 @@ class DecomposeAcoshPass(ArmPass):
2229
acosh(x) = log(x + sqrt((x-1)(x+1))
2330
"""
2431

25-
_passes_required_after: Set[Type[ExportPass]] = set()
32+
_passes_required_after: Set[Type[ExportPass]] = {
33+
DecomposeSqrtPass,
34+
InsertTableOpsPass,
35+
MatchArgRanksPass,
36+
ReplaceScalarWithTensorArgPassTOSAMI,
37+
MatchArgDtypePass,
38+
}
2639

2740
def call_operator(self, op, args, kwargs, meta, updated=False):
2841

0 commit comments

Comments
 (0)