Skip to content

Commit fd8bb35

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Convert ExportPasses to ArmPasses
All passes in backends/arm/_passes have been set to inherit from `ArmPass` instead of `ExportPass`. Furthermore, the attribute `exported_program` in `ArmPass` has been removed to instead let each subclasses set it. The reason is that it was type annotated as an `Optional[ExportedProgram]]` in `ArmPass` and Mypy complained about that when a subclass tried to access it without checking for `None`. By moving the attribute down to each subclass that needs it, the confusion around whether the value is `None` or not is elimininated. Signed-off-by: Martin Lindström <[email protected]> Change-Id: Ic73ecbeca255dfca74e23b4ce422dc06a094a058
1 parent ce6e2cf commit fd8bb35

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+222
-128
lines changed

backends/arm/_passes/_debug_passes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7+
from executorch.backends.arm._passes import ArmPass
78
from executorch.devtools.visualization.visualization_utils import visualize_graph
89
from executorch.exir import ExportedProgram
9-
from executorch.exir.pass_base import ExportPass, PassResult
10+
from executorch.exir.pass_base import PassResult
1011

1112

12-
class VisualizePass(ExportPass):
13+
class VisualizePass(ArmPass):
1314
"""
1415
This pass visualizes the graph at the point of insertion in the pass manager
1516
"""

backends/arm/_passes/add_bias_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1111
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1212
from executorch.backends.transforms.utils import create_constant_placeholder
13+
from executorch.exir import ExportedProgram
1314

1415
from executorch.exir.dialects._ops import ops as exir_ops
1516
from executorch.exir.pass_base import ExportPass, PassResult
@@ -26,6 +27,10 @@ class AddBiasPass(ArmPass):
2627

2728
targeted_ops = (exir_ops.edge.aten.convolution.default,)
2829

30+
def __init__(self, exported_program: ExportedProgram) -> None:
31+
super().__init__()
32+
self.exported_program = exported_program
33+
2934
def call(self, graph_module):
3035
modified = False
3136
for node in graph_module.graph.nodes:

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import cast, List, Set, Type
1111

1212
import torch
13+
from executorch.backends.arm._passes.arm_pass import ArmPass
1314
from executorch.backends.arm._passes.arm_pass_utils import create_node
1415
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1516
FoldAndAnnotateQParamsPass,
@@ -23,7 +24,7 @@
2324
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
2425

2526

26-
class AnnotateDecomposedMatmulPass(ExportPass):
27+
class AnnotateDecomposedMatmulPass(ArmPass):
2728
"""
2829
torch.matmul and it's equivalent operator @ can be decomposed in many ways, for instance:
2930
dq -> matmul -> q can become

backends/arm/_passes/arm_pass.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,12 @@
99
from abc import abstractmethod
1010
from typing import List, Optional, Set, Type
1111

12-
import torch
1312
from executorch.exir.pass_base import ExportPass, NodeMetadata
1413

1514

1615
class ArmPass(ExportPass):
1716
"""Base class for Arm passes"""
1817

19-
def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
20-
super(ArmPass, self).__init__()
21-
self.exported_program = exported_program
22-
2318
@property
2419
@abstractmethod
2520
def _passes_required_after(self) -> Set[Type[ExportPass]]:

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
207207
# needs to happen before AddBiasPass, but after the table ops are inserted
208208
# to be able to validate that conv2d has right dtype arguments.
209209
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
210-
self.add_pass(RewriteUpsamplePass(exported_program))
210+
self.add_pass(RewriteUpsamplePass())
211211
self.add_pass(AddBiasPass(exported_program))
212212

213213
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
@@ -292,7 +292,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
292292
self.add_pass(FuseViewCopyTransform())
293293
self.add_pass(FuseConstantArgsPass(exported_program))
294294
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
295-
self.add_pass(RewriteUpsamplePass(exported_program))
295+
self.add_pass(RewriteUpsamplePass())
296296
self.add_pass(AddBiasPass(exported_program))
297297
self.add_pass(InsertTableOpsPass(exported_program))
298298
self.add_pass(FuseEqualPlaceholdersPass(exported_program))

backends/arm/_passes/cast_bool_to_int8_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010

1111
import torch
1212

13+
from executorch.backends.arm._passes.arm_pass import ArmPass
1314
from executorch.exir.dialects._ops import ops as exir_ops
1415
from executorch.exir.pass_base import ExportPass
1516

1617

17-
class CastBoolToInt8Pass(ExportPass):
18+
class CastBoolToInt8Pass(ArmPass):
1819
"""Casts the input to int8 if it is not already and casts back the output to the original input dtype."""
1920

2021
_passes_required_after: Set[Type[ExportPass]] = set()

backends/arm/_passes/cast_int64_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,23 @@
99
from typing import Set, Type
1010

1111
import torch
12+
from executorch.backends.arm._passes.arm_pass import ArmPass
1213
from executorch.exir.pass_base import ExportPass, PassResult
1314
from torch._export.utils import is_buffer
15+
from torch.export import ExportedProgram
1416

1517
logger = logging.getLogger(__name__)
1618

1719

18-
class CastInt64BuffersToInt32Pass(ExportPass):
20+
class CastInt64BuffersToInt32Pass(ArmPass):
1921
"""
2022
Cast int64 buffers to int32 if the int64 data is in int32 range.
2123
"""
2224

2325
_passes_required_after: Set[Type[ExportPass]] = set()
2426

25-
def __init__(self, exported_program: torch.export.ExportedProgram):
26-
super(CastInt64BuffersToInt32Pass, self).__init__()
27+
def __init__(self, exported_program: ExportedProgram):
28+
super().__init__()
2729
self.exported_program = exported_program
2830

2931
def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):

backends/arm/_passes/cast_to_int32_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import torch
99

10+
from executorch.backends.arm._passes.arm_pass import ArmPass
1011
from executorch.exir.dialects._ops import ops as exir_ops
1112
from executorch.exir.pass_base import ExportPass
1213

1314

14-
class CastToInt32Pass(ExportPass):
15+
class CastToInt32Pass(ArmPass):
1516
"""Casts the input to int32 if it is not already and casts back the output to the original input dtype."""
1617

1718
_passes_required_after: Set[Type[ExportPass]] = set()

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88

99
from typing import Set, Type
1010

11+
from executorch.backends.arm._passes import ArmPass
12+
1113
from executorch.backends.arm._passes.add_bias_pass import AddBiasPass
1214
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
1315

1416
from executorch.exir.dialects._ops import ops as exir_ops
1517
from executorch.exir.pass_base import ExportPass
1618

1719

18-
class Conv1dUnsqueezePass(ExportPass):
20+
class Conv1dUnsqueezePass(ArmPass):
1921
"""
2022
This pass is used to change conv1d ops into conv2d since TOSA only
2123
supports 2d and 3d convolution. This is done by modifying the graph to do the
@@ -38,7 +40,11 @@ def call_operator(self, op, args, kwargs, meta):
3840
x = args[0]
3941
x_unsqueezed_shape = list(x.data.shape) + [1]
4042
x = super().call_operator(
41-
exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta
43+
exir_ops.edge.aten.view_copy.default,
44+
(x, x_unsqueezed_shape),
45+
{},
46+
meta,
47+
updated=True,
4248
)
4349

4450
w_meta = meta.copy()
@@ -48,7 +54,11 @@ def call_operator(self, op, args, kwargs, meta):
4854
w = args[1]
4955
w_unsqueezed_shape = list(w.data.shape) + [1]
5056
w = super().call_operator(
51-
exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta
57+
exir_ops.edge.aten.view_copy.default,
58+
(w, w_unsqueezed_shape),
59+
{},
60+
w_meta,
61+
updated=True,
5262
)
5363

5464
new_args = (
@@ -63,12 +73,16 @@ def call_operator(self, op, args, kwargs, meta):
6373
args[8],
6474
)
6575
x = super().call_operator(
66-
exir_ops.edge.aten.convolution.default, new_args, kwargs, meta
76+
exir_ops.edge.aten.convolution.default, new_args, kwargs, meta, updated=True
6777
)
6878

6979
x_squeezed_shape = list(x.data.shape)[:-1]
7080
x = super().call_operator(
71-
exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta
81+
exir_ops.edge.aten.view_copy.default,
82+
(x, x_squeezed_shape),
83+
{},
84+
meta,
85+
updated=True,
7286
)
7387

7488
return x

backends/arm/_passes/convert_any_default_dim_dims_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Set, Type
77

88
import torch
9+
from executorch.backends.arm._passes import ArmPass
910
from executorch.backends.arm._passes.convert_squeezes_to_view import (
1011
ConvertSqueezesToViewPass,
1112
)
@@ -18,7 +19,7 @@
1819
)
1920

2021

21-
class ConvertAnyDefaultDimDimsPass(ExportPass):
22+
class ConvertAnyDefaultDimDimsPass(ArmPass):
2223
"""
2324
Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction.
2425
Please refer to KeepDimsFalseToSqueezePass for an explanation of this coversion.

0 commit comments

Comments
 (0)