Skip to content

Commit ba79dc1

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Sort passes in _tosa_pipeline into blocks
The passes listed in `ArmPassManager._tosa_pipeline` can feel a bit arbitrary because there is no clearly intended structure or pattern being applied there. Restructure the list into clearly labelled blocks to make the code easier to read and maintain. Signed-off-by: Martin Lindström <[email protected]> Change-Id: Iadf37cda2c7a88cad80bf363062d38d492206be7
1 parent acf5b4b commit ba79dc1

File tree

3 files changed

+44
-16
lines changed

3 files changed

+44
-16
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ def _transform(self, graph_module: GraphModule):
159159
def _tosa_pipeline(
160160
self, exported_program: ExportedProgram, graph_module: GraphModule
161161
) -> GraphModule:
162+
# Preprocessing passes
163+
162164
self.add_pass(AnnotateOutputDimOrderPass())
165+
166+
# Node transformation passes (pre q/dq folding)
167+
163168
self.add_pass(FuseQuantizedActivationPass())
164169
self.add_pass(RemoveGetItemPass())
165170
self.add_pass(ConvertToClampPass())
@@ -174,8 +179,19 @@ def _tosa_pipeline(
174179
self.add_pass(ConvertELUParamsPass())
175180
self.add_pass(ConvertSplitToSlicePass())
176181
self.add_pass(QuantizeOperatorArguments())
182+
183+
# Fold Q/DQ nodes, insert INT8/INT32 rescales.
184+
177185
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
178186
self.add_pass(FuseDuplicateUsersPass())
187+
# TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
188+
# before FoldAndAnnotateQParamsPass but is unable to at the moment.
189+
# Ticket: MLETORCH-1539
190+
self.add_pass(DecomposeLinearPass())
191+
self.add_pass(InsertRescaleInt32Pass())
192+
193+
# Node transformation passes (post q/dq folding)
194+
179195
self.add_pass(DecomposeExpm1Pass())
180196
self.add_pass(DecomposeLogitPass())
181197
self.add_pass(DecomposeMaskedFill())
@@ -196,57 +212,67 @@ def _tosa_pipeline(
196212
self.add_pass(DecomposeSignPass())
197213
self.add_pass(DecomposeFloorDividePass())
198214
self.add_pass(DecomposeDivTensorModePass())
215+
self.add_pass(DecomposeGeluPass())
216+
self.add_pass(DecomposeAddSubAlphaPass())
217+
self.add_pass(DecomposeGroupedConv())
218+
self.add_pass(Conv1dUnsqueezePass())
219+
220+
# Scalars -> tensors, match tensor dtypes and ranks.
221+
199222
self.add_pass(ReplaceScalarWithTensorByProfilePass())
223+
self.add_pass(MatchArgDtypePass())
224+
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
225+
# TODO: Move DecomposeNotEqualPass to before or after this block of
226+
# passes. Ticket: MLETORCH-1540
227+
self.add_pass(DecomposeNotEqualPass())
228+
self.add_pass(MatchArgRanksPass(exported_program))
229+
self.add_pass(FuseConstantArgsPass(exported_program))
230+
231+
# Node transformation passes (post scalar-removal)
232+
200233
self.add_pass(DecomposeRemainderPass())
201234
self.add_pass(DecomposeDivTensorModePass())
202235
self.add_pass(DecomposeEmbeddingPass())
203236
self.add_pass(FuseBatchnorm2DPass(exported_program))
204237
self.add_pass(ConvertMmToBmmPass())
205238
self.add_pass(DecomposeGluPass())
206-
self.add_pass(DecomposeLinearPass())
207239
self.add_pass(DecomposeLeakyReLUPass())
208-
self.add_pass(DecomposeNotEqualPass())
209240
self.add_pass(DecomposeDivPass())
210-
self.add_pass(DecomposeAddSubAlphaPass())
211241
self.add_pass(DecomposeSoftmaxPass())
212-
self.add_pass(DecomposeGeluPass())
213242
self.add_pass(ConvertFullLikeToFullPass())
214243
self.add_pass(ConvertMinMaxPass())
215244
self.add_pass(ConvertAnyDefaultDimDimsPass())
216-
self.add_pass(MatchArgDtypePass())
217-
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
218-
self.add_pass(MatchArgRanksPass(exported_program))
219245
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
220246
self.add_pass(DecomposeAvgPool2d())
221247
self.add_pass(
222248
DecorateFp32toInt32CastingPass()
223249
) # Require that no new fp32->int32 is introduced after this pass
224250
self.add_pass(ComputeConstantOpsAOT(exported_program))
225-
226-
self.add_pass(DecomposeGroupedConv())
227251
self.add_pass(ConvertExpandCopyToRepeatPass())
228252
self.add_pass(UnsqueezeBeforeRepeatPass())
229253
self.add_pass(DecomposeCumsumPass(exported_program))
230-
self.add_pass(Conv1dUnsqueezePass())
231254
self.add_pass(DecomposeMaxPool2DPass())
232255
self.add_pass(SizeAdjustInputPass())
233256
self.add_pass(DecomposeSelectPass())
234257
self.add_pass(ConvertSqueezesToViewPass())
235258
self.add_pass(CastToInt32Pass())
236259
self.add_pass(BroadcastArgsPass())
237-
238260
self.add_pass(ConvertPermuteSingletonToViewPass())
239261
self.add_pass(FuseViewCopyTransform())
240-
self.add_pass(FuseConstantArgsPass(exported_program))
241262
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
242-
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
263+
self.add_pass(DecomposeSumPass())
243264
self.add_pass(InsertTableOpsPass(exported_program))
265+
266+
# Aten -> TOSA transformation passes
267+
244268
self.add_pass(RewriteUpsamplePass())
245269
self.add_pass(RewriteConv2dPass(exported_program))
246270
self.add_pass(RewriteMatmulPass())
271+
272+
# Postprocessing/cleanup passes
273+
274+
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
247275
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
248-
self.add_pass(InsertRescaleInt32Pass())
249-
self.add_pass(DecomposeSumPass())
250276
self.add_pass(ToTosaMemoryFormatPass(exported_program))
251277
self.add_pass(RemoveNoopPass())
252278
self.add_pass(InsertRescalePass())

backends/arm/_passes/decompose_linear_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
create_node,
1313
get_first_fake_tensor,
1414
)
15+
from executorch.backends.arm._passes.insert_rescales_pass import InsertRescaleInt32Pass
1516
from executorch.exir.dialects._ops import ops as exir_ops
1617
from executorch.exir.pass_base import ExportPass, PassResult
1718

@@ -26,7 +27,7 @@ class DecomposeLinearPass(ArmPass):
2627
output = view(conv2d)
2728
"""
2829

29-
_passes_required_after: Set[Type[ExportPass]] = set()
30+
_passes_required_after: Set[Type[ExportPass]] = {InsertRescaleInt32Pass}
3031

3132
def call(self, graph_module):
3233
for node in graph_module.graph.nodes:

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(self, exported_program: ExportedProgram) -> None:
5757
exir_ops.edge.aten.lt.Tensor,
5858
exir_ops.edge.aten.le.Tensor,
5959
exir_ops.edge.aten.pow.Tensor_Tensor,
60+
exir_ops.edge.aten.remainder.Tensor,
6061
exir_ops.edge.aten.where.self,
6162
exir_ops.edge.aten.bitwise_and.Tensor,
6263
exir_ops.edge.aten.bitwise_xor.Tensor,

0 commit comments

Comments
 (0)