Skip to content

Commit 67e45fe

Browse files
martinlsm and Martin Lindström authored
Arm backend: Add add_passes method (#15845)
Add a method called add_passes to ArmPassManager. It groups blocks of passes together, making them clearer for the reader.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai
Signed-off-by: Martin Lindström <[email protected]>
Co-authored-by: Martin Lindström <[email protected]>
1 parent b1e3e28 commit 67e45fe

File tree

2 files changed

+181
-147
lines changed

2 files changed

+181
-147
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 176 additions & 146 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88

99
from collections import defaultdict
10+
from collections.abc import Sequence
1011

1112
import executorch.backends.arm.tosa.dialect # noqa: unused
1213
from executorch.backends.arm._passes import (
@@ -112,6 +113,7 @@
112113
TosaSpecification,
113114
)
114115
from executorch.exir import ExportedProgram
116+
from executorch.exir.pass_base import ExportPass
115117
from executorch.exir.pass_manager import PassManager
116118
from torch.fx import GraphModule
117119
from torch.fx.passes.infra.pass_base import PassResult
@@ -150,6 +152,11 @@ def validate_constraints_mandatory(self):
150152

151153
raise RuntimeError(error_msg)
152154

155+
def add_passes(self, passes: Sequence[ExportPass | None]):
156+
for p in passes:
157+
if p is not None:
158+
self.add_pass(p)
159+
153160
def _transform(self, graph_module: GraphModule):
154161
with TosaLoweringContext(self.tosa_spec):
155162
return self(graph_module).graph_module
@@ -158,120 +165,136 @@ def _tosa_pipeline(
158165
self, exported_program: ExportedProgram, graph_module: GraphModule
159166
) -> GraphModule:
160167
# Preprocessing passes
161-
162168
self.add_pass(AnnotateOutputDimOrderPass())
163169

164170
# Node transformation passes (pre q/dq folding)
165-
166-
self.add_pass(FuseQuantizedActivationPass())
167-
self.add_pass(RemoveGetItemPass())
168-
self.add_pass(ConvertToClampPass())
169-
self.add_pass(DecomposeGroupNormPass())
170-
self.add_pass(DecomposeLayerNormPass())
171-
self.add_pass(DecomposeBatchNormNoStatsPass())
172-
self.add_pass(DecomposeVarPass())
173-
self.add_pass(
174-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
171+
self.add_passes(
172+
[
173+
FuseQuantizedActivationPass(),
174+
RemoveGetItemPass(),
175+
ConvertToClampPass(),
176+
DecomposeGroupNormPass(),
177+
DecomposeLayerNormPass(),
178+
DecomposeBatchNormNoStatsPass(),
179+
DecomposeVarPass(),
180+
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
181+
AnnotateDecomposedMatmulPass(),
182+
ConvertELUParamsPass(),
183+
ConvertSplitToSlicePass(),
184+
QuantizeOperatorArguments(),
185+
]
175186
)
176-
self.add_pass(AnnotateDecomposedMatmulPass())
177-
self.add_pass(ConvertELUParamsPass())
178-
self.add_pass(ConvertSplitToSlicePass())
179-
self.add_pass(QuantizeOperatorArguments())
180187

181188
# Fold Q/DQ nodes, insert INT8/INT32 rescales.
182-
183-
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
184-
self.add_pass(FuseDuplicateUsersPass())
185-
# TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
186-
# before FoldAndAnnotateQParamsPass but is unable to at the moment.
187-
# Ticket: MLETORCH-1539
188-
self.add_pass(DecomposeLinearPass())
189-
self.add_pass(InsertRescaleInt32Pass())
189+
self.add_passes(
190+
[
191+
FoldAndAnnotateQParamsPass(exported_program),
192+
FuseDuplicateUsersPass(),
193+
# TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
194+
# before FoldAndAnnotateQParamsPass but is unable to at the moment.
195+
# Ticket: MLETORCH-1539
196+
DecomposeLinearPass(),
197+
InsertRescaleInt32Pass(),
198+
]
199+
)
190200

191201
# Node transformation passes (post q/dq folding)
192-
193-
self.add_pass(DecomposeLogitPass())
194-
self.add_pass(DecomposeMaskedFill())
195-
self.add_pass(DecomposeRoundPass())
196-
self.add_pass(DecomposeAcoshPass())
197-
self.add_pass(DecomposeAsinhPass())
198-
self.add_pass(DecomposeCoshPass())
199-
self.add_pass(DecomposeAsinAndAcosPass())
200-
self.add_pass(DecomposeSqrtPass())
201-
self.add_pass(DecomposeAtanPass())
202-
self.add_pass(DecomposeAtanhPass())
203-
self.add_pass(DecomposeAddmmPass())
204-
self.add_pass(DecomposeEluPass())
205-
self.add_pass(DecomposeExpm1Pass())
206-
self.add_pass(ConvertIntPowToMuls())
207-
self.add_pass(CastBoolToInt8Pass())
208-
self.add_pass(DecomposeSinhPass())
209-
self.add_pass(DecomposeSignPass())
210-
self.add_pass(DecomposeFloorDividePass())
211-
self.add_pass(DecomposeGeluPass())
212-
self.add_pass(DecomposeAddSubAlphaPass())
213-
self.add_pass(DecomposeGroupedConv())
214-
self.add_pass(Conv1dUnsqueezePass())
202+
self.add_passes(
203+
[
204+
DecomposeLogitPass(),
205+
DecomposeMaskedFill(),
206+
DecomposeRoundPass(),
207+
DecomposeAcoshPass(),
208+
DecomposeAsinhPass(),
209+
DecomposeCoshPass(),
210+
DecomposeAsinAndAcosPass(),
211+
DecomposeSqrtPass(),
212+
DecomposeAtanPass(),
213+
DecomposeAtanhPass(),
214+
DecomposeAddmmPass(),
215+
DecomposeEluPass(),
216+
DecomposeExpm1Pass(),
217+
ConvertIntPowToMuls(),
218+
CastBoolToInt8Pass(),
219+
DecomposeSinhPass(),
220+
DecomposeSignPass(),
221+
DecomposeFloorDividePass(),
222+
DecomposeGeluPass(),
223+
DecomposeAddSubAlphaPass(),
224+
DecomposeGroupedConv(),
225+
Conv1dUnsqueezePass(),
226+
]
227+
)
215228

216229
# Scalars -> tensors, match tensor dtypes and ranks.
217-
218-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
219-
self.add_pass(ConvertFullLikeToFullPass())
220-
self.add_pass(MatchArgDtypePass())
221-
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
222-
# TODO: Move DecomposeNotEqualPass to before or after this block of
223-
# passes. Ticket: MLETORCH-1540
224-
self.add_pass(DecomposeNotEqualPass())
225-
self.add_pass(MatchArgRanksPass(exported_program))
226-
self.add_pass(FuseConstantArgsPass(exported_program))
230+
self.add_passes(
231+
[
232+
ReplaceScalarWithTensorByProfilePass(),
233+
ConvertFullLikeToFullPass(),
234+
MatchArgDtypePass(),
235+
UnsqueezeScalarPlaceholdersPass(exported_program),
236+
# TODO: Move DecomposeNotEqualPass to before or after this block of
237+
# passes. Ticket: MLETORCH-1540
238+
DecomposeNotEqualPass(),
239+
MatchArgRanksPass(exported_program),
240+
FuseConstantArgsPass(exported_program),
241+
]
242+
)
227243

228244
# Node transformation passes (post scalar-removal)
229-
230-
self.add_pass(DecomposeRemainderPass())
231-
self.add_pass(DecomposeDivTensorModePass())
232-
self.add_pass(DecomposeEmbeddingPass())
233-
self.add_pass(FuseBatchnorm2DPass(exported_program))
234-
self.add_pass(ConvertMmToBmmPass())
235-
self.add_pass(DecomposeGluPass())
236-
self.add_pass(DecomposeLeakyReLUPass())
237-
self.add_pass(DecomposeDivPass())
238-
self.add_pass(DecomposeSoftmaxPass())
239-
self.add_pass(ConvertMinMaxPass())
240-
self.add_pass(DecomposeAnyPass())
241-
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
242-
self.add_pass(DecomposeAvgPool2d())
243-
self.add_pass(
244-
DecorateFp32toInt32CastingPass()
245-
) # Require that no new fp32->int32 is introduced after this pass
246-
self.add_pass(ComputeConstantOpsAOT(exported_program))
247-
self.add_pass(ConvertExpandCopyToRepeatPass())
248-
self.add_pass(UnsqueezeBeforeRepeatPass())
249-
self.add_pass(DecomposeCumsumPass(exported_program))
250-
self.add_pass(DecomposeMaxPool2DPass())
251-
self.add_pass(SizeAdjustInputPass())
252-
self.add_pass(DecomposeSelectPass())
253-
self.add_pass(ConvertSqueezesToViewPass())
254-
self.add_pass(CastToInt32Pass())
255-
self.add_pass(BroadcastArgsPass())
256-
self.add_pass(ConvertPermuteSingletonToViewPass())
257-
self.add_pass(FuseViewCopyTransformPass())
258-
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
259-
self.add_pass(DecomposeSumPass())
260-
self.add_pass(InsertTableOpsPass(exported_program))
245+
self.add_passes(
246+
[
247+
DecomposeRemainderPass(),
248+
DecomposeDivTensorModePass(),
249+
DecomposeEmbeddingPass(),
250+
FuseBatchnorm2DPass(exported_program),
251+
ConvertMmToBmmPass(),
252+
DecomposeGluPass(),
253+
DecomposeLeakyReLUPass(),
254+
DecomposeDivPass(),
255+
DecomposeSoftmaxPass(),
256+
ConvertMinMaxPass(),
257+
DecomposeAnyPass(),
258+
DecomposeAdaptiveAvgPool2dPass(),
259+
DecomposeAvgPool2d(),
260+
DecorateFp32toInt32CastingPass(),
261+
ComputeConstantOpsAOT(exported_program),
262+
ConvertExpandCopyToRepeatPass(),
263+
UnsqueezeBeforeRepeatPass(),
264+
DecomposeCumsumPass(exported_program),
265+
DecomposeMaxPool2DPass(),
266+
SizeAdjustInputPass(),
267+
DecomposeSelectPass(),
268+
ConvertSqueezesToViewPass(),
269+
CastToInt32Pass(),
270+
BroadcastArgsPass(),
271+
ConvertPermuteSingletonToViewPass(),
272+
FuseViewCopyTransformPass(),
273+
DecomposeConv2dWithInt16ActivationPass(),
274+
DecomposeSumPass(),
275+
InsertTableOpsPass(exported_program),
276+
]
277+
)
261278

262279
# Aten -> TOSA transformation passes
263-
264-
self.add_pass(RewriteUpsamplePass())
265-
self.add_pass(RewriteConv2dPass(exported_program))
266-
self.add_pass(RewriteMatmulPass())
280+
self.add_passes(
281+
[
282+
RewriteUpsamplePass(),
283+
RewriteConv2dPass(exported_program),
284+
RewriteMatmulPass(),
285+
]
286+
)
267287

268288
# Postprocessing/cleanup passes
269-
270-
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
271-
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
272-
self.add_pass(ToTosaMemoryFormatPass(exported_program))
273-
self.add_pass(RemoveNoopPass())
274-
self.add_pass(InsertRescalePass())
289+
self.add_passes(
290+
[
291+
CastInt64BuffersToInt32Pass(exported_program),
292+
FuseEqualPlaceholdersPass(exported_program),
293+
ToTosaMemoryFormatPass(exported_program),
294+
RemoveNoopPass(),
295+
InsertRescalePass(),
296+
]
297+
)
275298

276299
self.validate_constraints_mandatory()
277300
return self._transform(graph_module)
@@ -287,66 +310,73 @@ def transform_to_backend_pipeline(
287310
return self._tosa_pipeline(exported_program, graph_module)
288311
else:
289312
raise NotImplementedError(
290-
f"No pass pipeline implemented for {self.tosa_spec=}"
313+
f"No pass pipeline implemented for {self.tosa_spec}"
291314
)
292315

293316
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
294317
# Preprocessing passes
295-
296-
self.add_pass(
297-
RemoveGraphAssertsPass()
298-
) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph
318+
self.add_pass(RemoveGraphAssertsPass())
299319

300320
# Transformation passes (pre scalar -> tensor)
301-
302-
self.add_pass(ConvertInt64ConstOpsToInt32Pass())
303-
self.add_pass(ConvertInt64OutputOpsToInt32Pass())
304-
self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
305-
self.add_pass(DecomposeEmbeddingPass())
306-
self.add_pass(DecomposeScaledDotProductAttention())
307-
self.add_pass(DecomposeRoundPass())
308-
self.add_pass(DecomposeLogitPass())
309-
self.add_pass(CastBoolToInt8Pass())
310-
self.add_pass(DecomposeSignPass())
311-
self.add_pass(DecomposeAddmmPass())
312-
self.add_pass(DecomposeRemainderPass())
313-
self.add_pass(DecomposeFloorDividePass())
314-
self.add_pass(DecomposeDivTensorModePass())
321+
self.add_passes(
322+
[
323+
ConvertInt64ConstOpsToInt32Pass(),
324+
ConvertInt64OutputOpsToInt32Pass(),
325+
InsertInt32CastsAfterInt64PlaceholdersPass(),
326+
DecomposeEmbeddingPass(),
327+
DecomposeScaledDotProductAttention(),
328+
DecomposeRoundPass(),
329+
DecomposeLogitPass(),
330+
CastBoolToInt8Pass(),
331+
DecomposeSignPass(),
332+
DecomposeAddmmPass(),
333+
DecomposeRemainderPass(),
334+
DecomposeFloorDividePass(),
335+
DecomposeDivTensorModePass(),
336+
]
337+
)
315338

316339
# Scalars -> tensors
317-
318-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
319-
self.add_pass(ScalarsToAttributePass())
340+
self.add_passes(
341+
[
342+
ReplaceScalarWithTensorByProfilePass(),
343+
ScalarsToAttributePass(),
344+
]
345+
)
320346

321347
# Transformation passes (post scalar removal)
322-
323-
self.add_pass(DecomposeAddSubAlphaPass())
324-
self.add_pass(DecomposeGroupNormPass())
325-
self.add_pass(DecomposeLayerNormPass())
326-
self.add_pass(DecomposeVarPass())
327-
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
328-
self.add_pass(DecomposeNotEqualPass())
329-
self.add_pass(DecomposeCosineSimilarityPass())
330-
self.add_pass(DecomposeGluPass())
331-
self.add_pass(DecomposeDivPass())
332-
self.add_pass(DecomposeLeakyReLUPass())
333-
self.add_pass(DecomposeLinearVectorNormPass())
334-
self.add_pass(DecomposeSqrtPass())
335-
self.add_pass(DecomposeSiluPass())
336-
self.add_pass(DecomposeAvgPool2d())
337-
if self.tosa_spec.is_U55_subset:
338-
# Numerically stable softmax uses amax which is not supported on Ethos-U55
339-
self.add_pass(DecomposeSoftmaxUnstablePass())
340-
else:
341-
self.add_pass(DecomposeSoftmaxPass())
342-
self.add_pass(ConvertMinMaxPass())
348+
self.add_passes(
349+
[
350+
DecomposeAddSubAlphaPass(),
351+
DecomposeGroupNormPass(),
352+
DecomposeLayerNormPass(),
353+
DecomposeVarPass(),
354+
DecomposeMeanDimPass(graph_module, self.tosa_spec),
355+
DecomposeNotEqualPass(),
356+
DecomposeCosineSimilarityPass(),
357+
DecomposeGluPass(),
358+
DecomposeDivPass(),
359+
DecomposeLeakyReLUPass(),
360+
DecomposeLinearVectorNormPass(),
361+
DecomposeSqrtPass(),
362+
DecomposeSiluPass(),
363+
DecomposeAvgPool2d(),
364+
(
365+
DecomposeSoftmaxUnstablePass()
366+
if self.tosa_spec.is_U55_subset
367+
else DecomposeSoftmaxPass()
368+
),
369+
ConvertMinMaxPass(),
370+
]
371+
)
343372

344373
# Postprocessing passes
345-
346-
self.add_pass(ReplaceInfValues())
347-
if not self.tosa_spec.is_U55_subset:
348-
# Uses where which is not supported on Ethos-U55
349-
self.add_pass(DecomposeMaskedFill())
374+
self.add_passes(
375+
[
376+
ReplaceInfValues(),
377+
DecomposeMaskedFill() if not self.tosa_spec.is_U55_subset else None,
378+
]
379+
)
350380

351381
return self._transform(graph_module)
352382

backends/arm/_passes/remove_graph_asserts_pass.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,13 @@
66
from typing import Set, Type
77

88
from executorch.backends.arm._passes.arm_pass import ArmPass
9+
10+
from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import (
11+
ConvertInt64ConstOpsToInt32Pass,
12+
)
913
from executorch.exir.pass_base import ExportPass
1014
from executorch.exir.passes import remove_graph_asserts_pass
1115

1216

1317
class RemoveGraphAssertsPass(remove_graph_asserts_pass.RemoveGraphAssertsPass, ArmPass):
14-
_passes_required_after: Set[Type[ExportPass]] = set()
18+
_passes_required_after: Set[Type[ExportPass]] = {ConvertInt64ConstOpsToInt32Pass}

0 commit comments

Comments
 (0)