77
88
99from collections import defaultdict
10+ from collections .abc import Sequence
1011
1112import executorch .backends .arm .tosa .dialect # noqa: unused
1213from executorch .backends .arm ._passes import (
112113 TosaSpecification ,
113114)
114115from executorch .exir import ExportedProgram
116+ from executorch .exir .pass_base import ExportPass
115117from executorch .exir .pass_manager import PassManager
116118from torch .fx import GraphModule
117119from torch .fx .passes .infra .pass_base import PassResult
@@ -150,6 +152,11 @@ def validate_constraints_mandatory(self):
150152
151153 raise RuntimeError (error_msg )
152154
def add_passes(self, passes: Sequence[ExportPass | None]):
    """Register several passes with this manager, preserving their order.

    ``None`` entries are skipped, which lets callers include a pass
    conditionally inline, e.g. ``SomePass() if cond else None``.

    Args:
        passes: Passes to register. Each entry that is not ``None`` is
            handed to ``self.add_pass`` in iteration order.
    """
    for candidate in passes:
        # Identity check on purpose: only ``None`` means "no pass here".
        if candidate is not None:
            self.add_pass(candidate)
159+
def _transform(self, graph_module: GraphModule):
    """Invoke this pass manager on ``graph_module`` and return the new module.

    The call ``self(graph_module)`` is made while a ``TosaLoweringContext``
    constructed from ``self.tosa_spec`` is active; the ``graph_module``
    attribute of the call's result is returned.

    Args:
        graph_module: The graph module to transform.

    Returns:
        The ``graph_module`` attribute of the result of calling this manager.
    """
    with TosaLoweringContext(self.tosa_spec):
        return self(graph_module).graph_module
def _tosa_pipeline(
    self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
    """Register the TOSA lowering passes in order and run them.

    Passes are queued stage by stage (the stage comments below name each
    group), then ``self.validate_constraints_mandatory()`` is called and the
    queued passes are applied to ``graph_module`` via ``self._transform``.

    Args:
        exported_program: Program handed to the passes that take it as a
            constructor argument.
        graph_module: The graph module to transform.

    Returns:
        The graph module returned by ``self._transform``.
    """
    # Preprocessing passes
    self.add_pass(AnnotateOutputDimOrderPass())

    # Node transformation passes (pre q/dq folding)
    self.add_passes(
        [
            FuseQuantizedActivationPass(),
            RemoveGetItemPass(),
            ConvertToClampPass(),
            DecomposeGroupNormPass(),
            DecomposeLayerNormPass(),
            DecomposeBatchNormNoStatsPass(),
            DecomposeVarPass(),
            DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
            AnnotateDecomposedMatmulPass(),
            ConvertELUParamsPass(),
            ConvertSplitToSlicePass(),
            QuantizeOperatorArguments(),
        ]
    )

    # Fold Q/DQ nodes, insert INT8/INT32 rescales.
    self.add_passes(
        [
            FoldAndAnnotateQParamsPass(exported_program),
            FuseDuplicateUsersPass(),
            # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
            # before FoldAndAnnotateQParamsPass but is unable to at the moment.
            # Ticket: MLETORCH-1539
            DecomposeLinearPass(),
            InsertRescaleInt32Pass(),
        ]
    )

    # Node transformation passes (post q/dq folding)
    self.add_passes(
        [
            DecomposeLogitPass(),
            DecomposeMaskedFill(),
            DecomposeRoundPass(),
            DecomposeAcoshPass(),
            DecomposeAsinhPass(),
            DecomposeCoshPass(),
            DecomposeAsinAndAcosPass(),
            DecomposeSqrtPass(),
            DecomposeAtanPass(),
            DecomposeAtanhPass(),
            DecomposeAddmmPass(),
            DecomposeEluPass(),
            DecomposeExpm1Pass(),
            ConvertIntPowToMuls(),
            CastBoolToInt8Pass(),
            DecomposeSinhPass(),
            DecomposeSignPass(),
            DecomposeFloorDividePass(),
            DecomposeGeluPass(),
            DecomposeAddSubAlphaPass(),
            DecomposeGroupedConv(),
            Conv1dUnsqueezePass(),
        ]
    )

    # Scalars -> tensors, match tensor dtypes and ranks.
    self.add_passes(
        [
            ReplaceScalarWithTensorByProfilePass(),
            ConvertFullLikeToFullPass(),
            MatchArgDtypePass(),
            UnsqueezeScalarPlaceholdersPass(exported_program),
            # TODO: Move DecomposeNotEqualPass to before or after this block of
            # passes. Ticket: MLETORCH-1540
            DecomposeNotEqualPass(),
            MatchArgRanksPass(exported_program),
            FuseConstantArgsPass(exported_program),
        ]
    )

    # Node transformation passes (post scalar-removal)
    self.add_passes(
        [
            DecomposeRemainderPass(),
            DecomposeDivTensorModePass(),
            DecomposeEmbeddingPass(),
            FuseBatchnorm2DPass(exported_program),
            ConvertMmToBmmPass(),
            DecomposeGluPass(),
            DecomposeLeakyReLUPass(),
            DecomposeDivPass(),
            DecomposeSoftmaxPass(),
            ConvertMinMaxPass(),
            DecomposeAnyPass(),
            DecomposeAdaptiveAvgPool2dPass(),
            DecomposeAvgPool2d(),
            # Require that no new fp32->int32 is introduced after this pass
            DecorateFp32toInt32CastingPass(),
            ComputeConstantOpsAOT(exported_program),
            ConvertExpandCopyToRepeatPass(),
            UnsqueezeBeforeRepeatPass(),
            DecomposeCumsumPass(exported_program),
            DecomposeMaxPool2DPass(),
            SizeAdjustInputPass(),
            DecomposeSelectPass(),
            ConvertSqueezesToViewPass(),
            CastToInt32Pass(),
            BroadcastArgsPass(),
            ConvertPermuteSingletonToViewPass(),
            FuseViewCopyTransformPass(),
            DecomposeConv2dWithInt16ActivationPass(),
            DecomposeSumPass(),
            InsertTableOpsPass(exported_program),
        ]
    )

    # Aten -> TOSA transformation passes
    self.add_passes(
        [
            RewriteUpsamplePass(),
            RewriteConv2dPass(exported_program),
            RewriteMatmulPass(),
        ]
    )

    # Postprocessing/cleanup passes
    self.add_passes(
        [
            CastInt64BuffersToInt32Pass(exported_program),
            FuseEqualPlaceholdersPass(exported_program),
            ToTosaMemoryFormatPass(exported_program),
            RemoveNoopPass(),
            InsertRescalePass(),
        ]
    )

    self.validate_constraints_mandatory()
    return self._transform(graph_module)
@@ -287,66 +310,73 @@ def transform_to_backend_pipeline(
287310 return self ._tosa_pipeline (exported_program , graph_module )
288311 else :
289312 raise NotImplementedError (
290- f"No pass pipeline implemented for { self .tosa_spec = } "
313+ f"No pass pipeline implemented for { self .tosa_spec } "
291314 )
292315
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
    """Register the annotation-time passes in order and run them.

    Passes are queued stage by stage (the stage comments below name each
    group) and then applied to ``graph_module`` via ``self._transform``.
    Two choices depend on ``self.tosa_spec.is_U55_subset``: which softmax
    decomposition is queued, and whether ``DecomposeMaskedFill`` is queued.

    Args:
        graph_module: The graph module to transform.

    Returns:
        The graph module returned by ``self._transform``.
    """
    # Preprocessing passes
    # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the
    # assertions in the graph.
    self.add_pass(RemoveGraphAssertsPass())

    # Transformation passes (pre scalar -> tensor)
    self.add_passes(
        [
            ConvertInt64ConstOpsToInt32Pass(),
            ConvertInt64OutputOpsToInt32Pass(),
            InsertInt32CastsAfterInt64PlaceholdersPass(),
            DecomposeEmbeddingPass(),
            DecomposeScaledDotProductAttention(),
            DecomposeRoundPass(),
            DecomposeLogitPass(),
            CastBoolToInt8Pass(),
            DecomposeSignPass(),
            DecomposeAddmmPass(),
            DecomposeRemainderPass(),
            DecomposeFloorDividePass(),
            DecomposeDivTensorModePass(),
        ]
    )

    # Scalars -> tensors
    self.add_passes(
        [
            ReplaceScalarWithTensorByProfilePass(),
            ScalarsToAttributePass(),
        ]
    )

    # Transformation passes (post scalar removal)
    self.add_passes(
        [
            DecomposeAddSubAlphaPass(),
            DecomposeGroupNormPass(),
            DecomposeLayerNormPass(),
            DecomposeVarPass(),
            DecomposeMeanDimPass(graph_module, self.tosa_spec),
            DecomposeNotEqualPass(),
            DecomposeCosineSimilarityPass(),
            DecomposeGluPass(),
            DecomposeDivPass(),
            DecomposeLeakyReLUPass(),
            DecomposeLinearVectorNormPass(),
            DecomposeSqrtPass(),
            DecomposeSiluPass(),
            DecomposeAvgPool2d(),
            # Numerically stable softmax uses amax which is not supported
            # on Ethos-U55
            (
                DecomposeSoftmaxUnstablePass()
                if self.tosa_spec.is_U55_subset
                else DecomposeSoftmaxPass()
            ),
            ConvertMinMaxPass(),
        ]
    )

    # Postprocessing passes
    self.add_passes(
        [
            ReplaceInfValues(),
            # Uses where which is not supported on Ethos-U55; the None
            # placeholder is dropped by add_passes.
            DecomposeMaskedFill() if not self.tosa_spec.is_U55_subset else None,
        ]
    )

    return self._transform(graph_module)
352382
0 commit comments