Skip to content

Commit 8da822c

Browse files
martinlsmAdrianLundellMartin Lindström
authored
Arm backend: Add pass order validation to ArmPassManager (#14148)
Introduce a mechanism to enforce required ordering of passes in ArmPassManager. Each ArmPass must now declare which passes are required to run after it, ensuring ordering constraints are always upheld. This prevents accidental breakage when modifying pass ordering in the manager. Ordering constraints are verified by the new method ArmPass.validate_constraints_mandatory. We considered reusing torch.fx.passes.infra.pass_manager.PassManager.validate_constraints, but that utility only checks pairwise ordering and cannot enforce that a pass is actually run, which did not meet our needs. This patch only implements the mechanism and tests for it. Defining the actual pass orderings are done in a later patch. ### Test plan The change comes with added unit tests in backends/arm/test/misc/test_pass_required_order.py Signed-off-by: Adrian Lundell <[email protected]> Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Adrian Lundell <[email protected]> Co-authored-by: Martin Lindström <[email protected]>
1 parent 246009b commit 8da822c

File tree

81 files changed

+489
-24
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+489
-24
lines changed

backends/arm/_passes/add_bias_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import Set, Type
7+
68
import torch
79
from executorch.backends.arm._passes import ArmPass
810
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
911
from executorch.backends.transforms.utils import create_constant_placeholder
1012

1113
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import PassResult
14+
from executorch.exir.pass_base import ExportPass, PassResult
1315
from torch.export.graph_signature import InputKind
1416

1517

@@ -19,6 +21,8 @@ class AddBiasPass(ArmPass):
1921
The bias is set to zero.
2022
"""
2123

24+
_passes_required_after: Set[Type[ExportPass]] = set()
25+
2226
targeted_ops = (exir_ops.edge.aten.convolution.default,)
2327

2428
def call(self, graph_module):

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import itertools
99
import operator
10-
from typing import cast, List
10+
from typing import cast, List, Set, Type
1111

1212
import torch
1313
from executorch.backends.arm._passes.arm_pass_utils import create_node
@@ -29,6 +29,8 @@ class AnnotateDecomposedMatmulPass(ExportPass):
2929
matmul-op (can be mm or bmm).
3030
"""
3131

32+
_passes_required_after: Set[Type[ExportPass]] = set()
33+
3234
def _match_partition_to_node(
3335
self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]
3436
) -> torch.fx.Node:

backends/arm/_passes/annotate_output_dim_order_pass.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
7+
from typing import Set, Type
8+
69
from executorch.backends.arm._passes import ArmPass
710
from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders
8-
from executorch.exir.pass_base import PassResult
11+
from executorch.exir.pass_base import ExportPass, PassResult
912

1013

1114
class AnnotateOutputDimOrderPass(ArmPass):
@@ -14,6 +17,8 @@ class AnnotateOutputDimOrderPass(ArmPass):
1417
for verifying that the dim order does not change unexpectedly in later passes.
1518
"""
1619

20+
_passes_required_after: Set[Type[ExportPass]] = set()
21+
1722
def call(self, graph_module):
1823
output_node = graph_module.graph.output_node()
1924
output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module)

backends/arm/_passes/arm_pass.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
# pyre-unsafe
77

88
import traceback
9-
from typing import Optional
9+
from abc import abstractmethod
10+
from typing import List, Optional, Set, Type
1011

1112
import torch
1213
from executorch.exir.pass_base import ExportPass, NodeMetadata
@@ -19,6 +20,36 @@ def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = No
1920
super(ArmPass, self).__init__()
2021
self.exported_program = exported_program
2122

23+
@property
24+
@abstractmethod
25+
def _passes_required_after(self) -> Set[Type[ExportPass]]:
26+
"""The subclass defines passes that must run after it"""
27+
pass
28+
29+
@staticmethod
30+
def get_required_passes(pass_) -> List[str]:
31+
"""
32+
Returns the list of passes that must be run after this pass, sorted by name.
33+
"""
34+
if hasattr(pass_, "_passes_required_after"):
35+
return sorted([ArmPass.get_name(p) for p in pass_._passes_required_after])
36+
else:
37+
return []
38+
39+
@staticmethod
40+
def get_name(pass_) -> str:
41+
"""
42+
Returns the name of the pass.
43+
"""
44+
if isinstance(pass_, ExportPass):
45+
return pass_.__class__.__name__
46+
elif hasattr(pass_, "__name__"):
47+
return pass_.__name__
48+
else:
49+
raise ValueError(
50+
f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute."
51+
)
52+
2253
def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
2354
if not updated:
2455
return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
# pyre-unsafe
99

10+
11+
from collections import defaultdict
12+
1013
import executorch.backends.arm.tosa.dialect # noqa: unused
1114
from executorch.backends.arm._passes import (
1215
AddBiasPass,
@@ -94,6 +97,7 @@
9497
UnsqueezeScalarPlaceholdersPass,
9598
)
9699

100+
from executorch.backends.arm._passes.arm_pass import ArmPass
97101
from executorch.backends.arm.tosa.specification import (
98102
TosaLoweringContext,
99103
TosaSpecification,
@@ -115,6 +119,32 @@ def __init__(self, tosa_spec: TosaSpecification) -> None:
115119
self.tosa_spec = tosa_spec
116120
super().__init__()
117121

122+
def validate_constraints_mandatory(self):
123+
"""
124+
Validates that necessary passes have run before transforming to backend.
125+
126+
Note that this differs from the original validate_constraints function, which
127+
only checks the order of passes.
128+
"""
129+
passes_to_run = defaultdict(list)
130+
131+
for current_pass in self.passes:
132+
current_pass_name = ArmPass.get_name(current_pass)
133+
for required_pass_name in ArmPass.get_required_passes(current_pass):
134+
passes_to_run[required_pass_name].append(current_pass_name)
135+
136+
passes_to_run.pop(current_pass_name, None)
137+
138+
if len(passes_to_run) > 0:
139+
error_msg = "The following constraints for passes are not met:\n"
140+
for required_pass, requiring_passes in passes_to_run.items():
141+
for requiring_pass in requiring_passes:
142+
error_msg += (
143+
f" - {required_pass} must run after {requiring_pass}\n"
144+
)
145+
146+
raise RuntimeError(error_msg)
147+
118148
def _transform(self, graph_module: GraphModule):
119149
with TosaLoweringContext(self.tosa_spec):
120150
return self(graph_module).graph_module
@@ -125,7 +155,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
125155
self.add_pass(RemoveGetItemPass())
126156
self.add_pass(ConvertSplitToSlicePass())
127157
self.add_pass(ConvertMmToBmmPass())
128-
self.add_pass(DecomposeLinearVectorNormPass())
129158
self.add_pass(
130159
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
131160
)
@@ -175,6 +204,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
175204
self.add_pass(RemoveNoopPass())
176205
self.add_pass(InsertRescalePass())
177206

207+
self.validate_constraints_mandatory()
178208
return self._transform(exported_program.graph_module)
179209

180210
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -258,6 +288,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
258288
self.add_pass(RemoveNoopPass())
259289
self.add_pass(InsertRescalePass())
260290

291+
self.validate_constraints_mandatory()
261292
return self._transform(exported_program.graph_module)
262293

263294
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):

backends/arm/_passes/broadcast_args_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import Set, Type
7+
68
from executorch.backends.arm._passes import ArmPass
79

810
from executorch.backends.arm._passes.arm_pass_utils import (
@@ -12,7 +14,7 @@
1214

1315
from executorch.exir.dialects._ops import ops as exir_ops
1416

15-
from executorch.exir.pass_base import PassResult
17+
from executorch.exir.pass_base import ExportPass, PassResult
1618
from torch.fx import GraphModule, Node
1719

1820

@@ -22,6 +24,8 @@ class BroadcastArgsPass(ArmPass):
2224
This is done when more than one arg needs broadcasting.
2325
"""
2426

27+
_passes_required_after: Set[Type[ExportPass]] = set()
28+
2529
targeted_ops = {
2630
exir_ops.edge.aten.add.Tensor,
2731
exir_ops.edge.aten.sub.Tensor,

backends/arm/_passes/cast_bool_to_int8_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input
77
# If input/output is bool lest add a cast/conversion pass before/after to/from int8.
88

9+
from typing import Set, Type
10+
911
import torch
1012

1113
from executorch.exir.dialects._ops import ops as exir_ops
@@ -15,6 +17,8 @@
1517
class CastBoolToInt8Pass(ExportPass):
1618
"""Casts the input to int8 if it is not already and casts back the output to the original input dtype."""
1719

20+
_passes_required_after: Set[Type[ExportPass]] = set()
21+
1822
targeted_ops = {
1923
exir_ops.edge.aten.bitwise_and.Tensor,
2024
exir_ops.edge.aten.bitwise_or.Tensor,

backends/arm/_passes/cast_int64_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# pyre-unsafe
77

88
import logging
9+
from typing import Set, Type
910

1011
import torch
1112
from executorch.exir.pass_base import ExportPass, PassResult
@@ -19,6 +20,8 @@ class CastInt64BuffersToInt32Pass(ExportPass):
1920
Cast int64 buffers to int32 if the int64 data is in int32 range.
2021
"""
2122

23+
_passes_required_after: Set[Type[ExportPass]] = set()
24+
2225
def __init__(self, exported_program: torch.export.ExportedProgram):
2326
super(CastInt64BuffersToInt32Pass, self).__init__()
2427
self.exported_program = exported_program

backends/arm/_passes/cast_to_int32_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import Set, Type
7+
68
import torch
79

810
from executorch.exir.dialects._ops import ops as exir_ops
@@ -12,6 +14,8 @@
1214
class CastToInt32Pass(ExportPass):
1315
"""Casts the input to int32 if it is not already and casts back the output to the original input dtype."""
1416

17+
_passes_required_after: Set[Type[ExportPass]] = set()
18+
1519
targeted_ops = {
1620
exir_ops.edge.aten.bitwise_left_shift.Tensor,
1721
exir_ops.edge.aten.bitwise_right_shift.Tensor,

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# LICENSE file in the root directory of this source tree.
77

88

9+
from typing import Set, Type
10+
911
from executorch.exir.dialects._ops import ops as exir_ops
1012
from executorch.exir.pass_base import ExportPass
1113

@@ -21,6 +23,8 @@ class Conv1dUnsqueezePass(ExportPass):
2123
3) squeeze the output back down to 3d.
2224
"""
2325

26+
_passes_required_after: Set[Type[ExportPass]] = set()
27+
2428
def call_operator(self, op, args, kwargs, meta):
2529
if op != exir_ops.edge.aten.convolution.default:
2630
return super().call_operator(op, args, kwargs, meta)

0 commit comments

Comments
 (0)