55# This source code is licensed under the BSD-style license found in the
66# LICENSE file in the root directory of this source tree.
77
8- # pyre-unsafe
9-
108
119from collections import defaultdict
1210
8987 QuantizeOperatorArguments ,
9088 RemoveNoopPass ,
9189 ReplaceInfValues ,
92- ReplaceScalarWithTensorArgPassTOSABI ,
93- ReplaceScalarWithTensorArgPassTOSAMI ,
90+ ReplaceScalarWithTensorByProfilePass ,
9491 RetraceFoldedDtypesPass ,
9592 RewriteConv2dPass ,
9693 RewriteMatmulPass ,
@@ -174,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
174171 self .add_pass (CastToInt32Pass ())
175172
176173 self .add_pass (CastBoolToInt8Pass ())
177- self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
174+ self .add_pass (ReplaceScalarWithTensorByProfilePass ())
178175 self .add_pass (AnnotateDecomposedMatmulPass ())
179176 self .add_pass (QuantizeOperatorArguments ())
180177 self .add_pass (ConvertELUParamsPass ())
@@ -194,7 +191,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
194191 self .add_pass (ConvertExpandCopyToRepeatPass ())
195192 self .add_pass (UnsqueezeBeforeRepeatPass ())
196193 self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
197- self .add_pass (DecomposeSumPass ())
198194 self .add_pass (DecomposeCumsumPass (exported_program ))
199195 self .add_pass (Conv1dUnsqueezePass ())
200196 self .add_pass (DecomposeMaxPool2DPass ())
@@ -215,10 +211,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
215211 self .add_pass (RewriteMatmulPass ())
216212 self .add_pass (RewriteUpsamplePass ())
217213 self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
214+ self .add_pass (InsertRescaleInt32Pass ())
215+ self .add_pass (DecomposeSumPass ())
218216 self .add_pass (ToTosaMemoryFormatPass (exported_program ))
219217 self .add_pass (RemoveNoopPass ())
220218 self .add_pass (InsertRescalePass ())
221- self .add_pass (InsertRescaleInt32Pass ())
222219
223220 self .validate_constraints_mandatory ()
224221 return self ._transform (exported_program .graph_module )
@@ -244,7 +241,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
244241 self .add_pass (DecomposeSinhPass ())
245242 self .add_pass (DecomposeSignPass ())
246243 self .add_pass (DecomposeDivTensorModePass ())
247- self .add_pass (ReplaceScalarWithTensorArgPassTOSAMI ())
244+ self .add_pass (ReplaceScalarWithTensorByProfilePass ())
248245 self .add_pass (DecomposeEmbeddingPass ())
249246 self .add_pass (FuseQuantizedActivationPass ())
250247 self .add_pass (RemoveGetItemPass ())
@@ -337,7 +334,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
337334 self .add_pass (DecomposeAddmmPass ())
338335 self .add_pass (DecomposeDivTensorModePass ())
339336 self .add_pass (DecomposeAddSubAlphaPass ())
340- self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
337+ self .add_pass (ReplaceScalarWithTensorByProfilePass ())
341338 self .add_pass (ScalarsToAttributePass ())
342339 self .add_pass (DecomposeGroupNormPass ())
343340 self .add_pass (DecomposeLayerNormPass ())
@@ -361,7 +358,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
361358
362359 self .add_pass (ConvertMinMaxPass ())
363360 self .add_pass (ReplaceInfValues ())
364- self .add_pass (DecomposeSumPass ())
365361
366362 if not self .tosa_spec .is_U55_subset :
367363 # Uses where which is not supported on Ethos-U55
0 commit comments