2727 ConvertIntPowToMuls ,
2828 ConvertMinMaxPass ,
2929 ConvertMmToBmmPass ,
30+ ConvertPermuteSingletonToViewPass ,
3031 ConvertSplitToSlicePass ,
3132 ConvertSqueezesToViewPass ,
3233 ConvertToClampPass ,
@@ -158,7 +159,12 @@ def _transform(self, graph_module: GraphModule):
158159 def _tosa_pipeline (
159160 self , exported_program : ExportedProgram , graph_module : GraphModule
160161 ) -> GraphModule :
162+ # Preprocessing passes
163+
161164 self .add_pass (AnnotateOutputDimOrderPass ())
165+
166+ # Node transformation passes (pre q/dq folding)
167+
162168 self .add_pass (FuseQuantizedActivationPass ())
163169 self .add_pass (RemoveGetItemPass ())
164170 self .add_pass (ConvertToClampPass ())
@@ -173,8 +179,19 @@ def _tosa_pipeline(
173179 self .add_pass (ConvertELUParamsPass ())
174180 self .add_pass (ConvertSplitToSlicePass ())
175181 self .add_pass (QuantizeOperatorArguments ())
182+
183+ # Fold Q/DQ nodes, insert INT8/INT32 rescales.
184+
176185 self .add_pass (FoldAndAnnotateQParamsPass (exported_program )) # type: ignore[call-arg]
177186 self .add_pass (FuseDuplicateUsersPass ())
187+ # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
188+ # before FoldAndAnnotateQParamsPass, but it cannot be moved there at the moment.
189+ # Ticket: MLETORCH-1539
190+ self .add_pass (DecomposeLinearPass ())
191+ self .add_pass (InsertRescaleInt32Pass ())
192+
193+ # Node transformation passes (post q/dq folding)
194+
178195 self .add_pass (DecomposeExpm1Pass ())
179196 self .add_pass (DecomposeLogitPass ())
180197 self .add_pass (DecomposeMaskedFill ())
@@ -195,56 +212,67 @@ def _tosa_pipeline(
195212 self .add_pass (DecomposeSignPass ())
196213 self .add_pass (DecomposeFloorDividePass ())
197214 self .add_pass (DecomposeDivTensorModePass ())
215+ self .add_pass (DecomposeGeluPass ())
216+ self .add_pass (DecomposeAddSubAlphaPass ())
217+ self .add_pass (DecomposeGroupedConv ())
218+ self .add_pass (Conv1dUnsqueezePass ())
219+
220+ # Scalars -> tensors, match tensor dtypes and ranks.
221+
198222 self .add_pass (ReplaceScalarWithTensorByProfilePass ())
223+ self .add_pass (ConvertFullLikeToFullPass ())
224+ self .add_pass (MatchArgDtypePass ())
225+ self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
226+ # TODO: Move DecomposeNotEqualPass to before or after this block of
227+ # passes. Ticket: MLETORCH-1540
228+ self .add_pass (DecomposeNotEqualPass ())
229+ self .add_pass (MatchArgRanksPass (exported_program ))
230+ self .add_pass (FuseConstantArgsPass (exported_program ))
231+
232+ # Node transformation passes (post scalar-removal)
233+
199234 self .add_pass (DecomposeRemainderPass ())
200235 self .add_pass (DecomposeDivTensorModePass ())
201236 self .add_pass (DecomposeEmbeddingPass ())
202237 self .add_pass (FuseBatchnorm2DPass (exported_program ))
203238 self .add_pass (ConvertMmToBmmPass ())
204239 self .add_pass (DecomposeGluPass ())
205- self .add_pass (DecomposeLinearPass ())
206240 self .add_pass (DecomposeLeakyReLUPass ())
207- self .add_pass (DecomposeNotEqualPass ())
208241 self .add_pass (DecomposeDivPass ())
209- self .add_pass (DecomposeAddSubAlphaPass ())
210242 self .add_pass (DecomposeSoftmaxPass ())
211- self .add_pass (DecomposeGeluPass ())
212- self .add_pass (ConvertFullLikeToFullPass ())
213243 self .add_pass (ConvertMinMaxPass ())
214244 self .add_pass (ConvertAnyDefaultDimDimsPass ())
215- self .add_pass (MatchArgDtypePass ())
216- self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
217- self .add_pass (MatchArgRanksPass (exported_program ))
218245 self .add_pass (DecomposeAdaptiveAvgPool2dPass ())
219246 self .add_pass (DecomposeAvgPool2d ())
220247 self .add_pass (
221248 DecorateFp32toInt32CastingPass ()
222249 ) # Require that no new fp32->int32 is introduced after this pass
223250 self .add_pass (ComputeConstantOpsAOT (exported_program ))
224-
225- self .add_pass (DecomposeGroupedConv ())
226251 self .add_pass (ConvertExpandCopyToRepeatPass ())
227252 self .add_pass (UnsqueezeBeforeRepeatPass ())
228253 self .add_pass (DecomposeCumsumPass (exported_program ))
229- self .add_pass (Conv1dUnsqueezePass ())
230254 self .add_pass (DecomposeMaxPool2DPass ())
231255 self .add_pass (SizeAdjustInputPass ())
232256 self .add_pass (DecomposeSelectPass ())
233257 self .add_pass (ConvertSqueezesToViewPass ())
234258 self .add_pass (CastToInt32Pass ())
235259 self .add_pass (BroadcastArgsPass ())
236-
260+ self . add_pass ( ConvertPermuteSingletonToViewPass ())
237261 self .add_pass (FuseViewCopyTransform ())
238- self .add_pass (FuseConstantArgsPass (exported_program ))
239262 self .add_pass (DecomposeConv2dWithInt16ActivationPass ())
240- self .add_pass (CastInt64BuffersToInt32Pass ( exported_program ))
263+ self .add_pass (DecomposeSumPass ( ))
241264 self .add_pass (InsertTableOpsPass (exported_program ))
265+
266+ # Aten -> TOSA transformation passes
267+
242268 self .add_pass (RewriteUpsamplePass ())
243269 self .add_pass (RewriteConv2dPass (exported_program ))
244270 self .add_pass (RewriteMatmulPass ())
271+
272+ # Postprocessing/cleanup passes
273+
274+ self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
245275 self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
246- self .add_pass (InsertRescaleInt32Pass ())
247- self .add_pass (DecomposeSumPass ())
248276 self .add_pass (ToTosaMemoryFormatPass (exported_program ))
249277 self .add_pass (RemoveNoopPass ())
250278 self .add_pass (InsertRescalePass ())
0 commit comments