@@ -155,79 +155,26 @@ def _transform(self, graph_module: GraphModule):
155155 with TosaLoweringContext (self .tosa_spec ):
156156 return self (graph_module ).graph_module
157157
158- def _tosa_INT_pipeline (
158+ def _tosa_pipeline (
159159 self , exported_program : ExportedProgram , graph_module : GraphModule
160160 ) -> GraphModule :
161161 self .add_pass (AnnotateOutputDimOrderPass ())
162162 self .add_pass (FuseQuantizedActivationPass ())
163163 self .add_pass (RemoveGetItemPass ())
164- self .add_pass (ConvertSplitToSlicePass ())
165- self .add_pass (ConvertMmToBmmPass ())
166- self .add_pass (DecomposeMeanDimPass (graph_module , self .tosa_spec ))
167- self .add_pass (ConvertFullLikeToFullPass ())
168164 self .add_pass (ConvertToClampPass ())
169- self .add_pass (ConvertMinMaxPass ())
170- self .add_pass (ConvertAnyDefaultDimDimsPass ())
171- self .add_pass (MatchArgDtypePass ())
172- if self .tosa_spec .is_U55_subset :
173- self .add_pass (CastToInt32Pass ())
174-
175- self .add_pass (CastBoolToInt8Pass ())
176- self .add_pass (ReplaceScalarWithTensorByProfilePass ())
165+ self .add_pass (DecomposeGroupNormPass ())
166+ self .add_pass (DecomposeLayerNormPass ())
167+ self .add_pass (DecomposeBatchNormNoStatsPass ())
168+ self .add_pass (DecomposeVarPass ())
169+ self .add_pass (
170+ DecomposeMeanDimPass (exported_program .graph_module , self .tosa_spec )
171+ )
177172 self .add_pass (AnnotateDecomposedMatmulPass ())
178- self .add_pass (QuantizeOperatorArguments ())
179173 self .add_pass (ConvertELUParamsPass ())
174+ self .add_pass (ConvertSplitToSlicePass ())
175+ self .add_pass (QuantizeOperatorArguments ())
180176 self .add_pass (FoldAndAnnotateQParamsPass (exported_program )) # type: ignore[call-arg]
181177 self .add_pass (FuseDuplicateUsersPass ())
182- self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
183- self .add_pass (MatchArgRanksPass (exported_program ))
184- if self .tosa_spec .is_U55_subset :
185- self .add_pass (BroadcastArgsPass ())
186- self .add_pass (DecomposeLinearPass ())
187- self .add_pass (DecomposeAdaptiveAvgPool2dPass ())
188- self .add_pass (DecomposeAvgPool2d ())
189- self .add_pass (ComputeConstantOpsAOT (exported_program ))
190-
191- self .add_pass (DecomposeGroupedConv ())
192-
193- self .add_pass (ConvertExpandCopyToRepeatPass ())
194- self .add_pass (UnsqueezeBeforeRepeatPass ())
195- self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
196- self .add_pass (DecomposeCumsumPass (exported_program ))
197- self .add_pass (Conv1dUnsqueezePass ())
198- self .add_pass (DecomposeMaxPool2DPass ())
199- self .add_pass (SizeAdjustInputPass ())
200- self .add_pass (DecomposeSelectPass ())
201- self .add_pass (ConvertSqueezesToViewPass ())
202-
203- self .add_pass (FuseViewCopyTransform ())
204- self .add_pass (FuseConstantArgsPass (exported_program ))
205- self .add_pass (InsertTableOpsPass (exported_program ))
206- # If we have a conv2d with int16 activation split up into a convolution
207- # and an addition, to work-around the lack of support for int48 in torch
208- # needs to happen before RewriteConv2dPass, but after the table ops are inserted
209- # to be able to validate that conv2d has right dtype arguments.
210- self .add_pass (DecomposeConv2dWithInt16ActivationPass ())
211- self .add_pass (RewriteConv2dPass (exported_program ))
212-
213- self .add_pass (RewriteMatmulPass ())
214- self .add_pass (RewriteUpsamplePass ())
215- self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
216-
217- self .add_pass (InsertRescaleInt32Pass ())
218- self .add_pass (DecomposeSumPass ())
219- self .add_pass (ToTosaMemoryFormatPass (exported_program ))
220- self .add_pass (RemoveNoopPass ())
221- self .add_pass (InsertRescalePass ())
222-
223- self .validate_constraints_mandatory ()
224- return self ._transform (graph_module )
225-
226- def _tosa_FP_pipeline (
227- self , exported_program : ExportedProgram , graph_module : GraphModule
228- ) -> GraphModule :
229- self .add_pass (AnnotateOutputDimOrderPass ())
230- self .add_pass (FuseDuplicateUsersPass ())
231178 self .add_pass (DecomposeExpm1Pass ())
232179 self .add_pass (DecomposeLogitPass ())
233180 self .add_pass (DecomposeMaskedFill ())
@@ -252,32 +199,20 @@ def _tosa_FP_pipeline(
252199 self .add_pass (DecomposeRemainderPass ())
253200 self .add_pass (DecomposeDivTensorModePass ())
254201 self .add_pass (DecomposeEmbeddingPass ())
255- self .add_pass (FuseQuantizedActivationPass ())
256- self .add_pass (RemoveGetItemPass ())
257- self .add_pass (ConvertSplitToSlicePass ())
258202 self .add_pass (FuseBatchnorm2DPass (exported_program ))
259203 self .add_pass (ConvertMmToBmmPass ())
260204 self .add_pass (DecomposeGluPass ())
261205 self .add_pass (DecomposeLinearPass ())
262206 self .add_pass (DecomposeLeakyReLUPass ())
263- self .add_pass (DecomposeGroupNormPass ())
264- self .add_pass (DecomposeLayerNormPass ())
265- self .add_pass (DecomposeBatchNormNoStatsPass ())
266- self .add_pass (DecomposeVarPass ())
267- self .add_pass (DecomposeMeanDimPass (graph_module , self .tosa_spec ))
268207 self .add_pass (DecomposeNotEqualPass ())
269208 self .add_pass (DecomposeDivPass ())
270209 self .add_pass (DecomposeAddSubAlphaPass ())
271210 self .add_pass (DecomposeSoftmaxPass ())
272211 self .add_pass (DecomposeGeluPass ())
273212 self .add_pass (ConvertFullLikeToFullPass ())
274- self .add_pass (ConvertToClampPass ())
275213 self .add_pass (ConvertMinMaxPass ())
276214 self .add_pass (ConvertAnyDefaultDimDimsPass ())
277215 self .add_pass (MatchArgDtypePass ())
278- self .add_pass (AnnotateDecomposedMatmulPass ())
279- self .add_pass (QuantizeOperatorArguments ())
280- self .add_pass (FoldAndAnnotateQParamsPass (exported_program )) # type: ignore[call-arg]
281216 self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
282217 self .add_pass (MatchArgRanksPass (exported_program ))
283218 self .add_pass (DecomposeAdaptiveAvgPool2dPass ())
@@ -290,22 +225,26 @@ def _tosa_FP_pipeline(
290225 self .add_pass (DecomposeGroupedConv ())
291226 self .add_pass (ConvertExpandCopyToRepeatPass ())
292227 self .add_pass (UnsqueezeBeforeRepeatPass ())
293- self .add_pass (DecomposeSumPass ())
294228 self .add_pass (DecomposeCumsumPass (exported_program ))
295229 self .add_pass (Conv1dUnsqueezePass ())
296230 self .add_pass (DecomposeMaxPool2DPass ())
297231 self .add_pass (SizeAdjustInputPass ())
298232 self .add_pass (DecomposeSelectPass ())
299233 self .add_pass (ConvertSqueezesToViewPass ())
234+ self .add_pass (CastToInt32Pass ())
235+ self .add_pass (BroadcastArgsPass ())
300236
301237 self .add_pass (FuseViewCopyTransform ())
302238 self .add_pass (FuseConstantArgsPass (exported_program ))
303- self .add_pass (RewriteConv2dPass ( exported_program ))
239+ self .add_pass (DecomposeConv2dWithInt16ActivationPass ( ))
304240 self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
305- self .add_pass (RewriteUpsamplePass ())
306241 self .add_pass (InsertTableOpsPass (exported_program ))
242+ self .add_pass (RewriteUpsamplePass ())
243+ self .add_pass (RewriteConv2dPass (exported_program ))
307244 self .add_pass (RewriteMatmulPass ())
308245 self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
246+ self .add_pass (InsertRescaleInt32Pass ())
247+ self .add_pass (DecomposeSumPass ())
309248 self .add_pass (ToTosaMemoryFormatPass (exported_program ))
310249 self .add_pass (RemoveNoopPass ())
311250 self .add_pass (InsertRescalePass ())
@@ -317,10 +256,11 @@ def transform_to_backend_pipeline(
317256 self , exported_program : ExportedProgram , graph_module : GraphModule
318257 ):
319258 """Apply passes before transforming program to backend"""
320- if self .tosa_spec == TosaSpecification .create_from_string ("TOSA-1.0+FP" ):
321- return self ._tosa_FP_pipeline (exported_program , graph_module )
322- elif self .tosa_spec == TosaSpecification .create_from_string ("TOSA-1.0+INT" ):
323- return self ._tosa_INT_pipeline (exported_program , graph_module )
259+ if self .tosa_spec in (
260+ TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
261+ TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
262+ ):
263+ return self ._tosa_pipeline (exported_program , graph_module )
324264 else :
325265 raise NotImplementedError (
326266 f"No pass pipeline implemented for { self .tosa_spec = } "
0 commit comments