|
88 | 88 | QuantizeOperatorArguments, |
89 | 89 | RemoveNoopPass, |
90 | 90 | ReplaceInfValues, |
91 | | - ReplaceScalarWithTensorByProfilePass, |
| 91 | + ReplaceScalarWithTensorArgPassTOSABI, |
| 92 | + ReplaceScalarWithTensorArgPassTOSAMI, |
92 | 93 | RetraceFoldedDtypesPass, |
93 | 94 | RewriteUpsamplePass, |
94 | 95 | ScalarsToAttributePass, |
@@ -170,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
170 | 171 | self.add_pass(CastToInt32Pass()) |
171 | 172 |
|
172 | 173 | self.add_pass(CastBoolToInt8Pass()) |
173 | | - self.add_pass(ReplaceScalarWithTensorByProfilePass()) |
| 174 | + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) |
174 | 175 | self.add_pass(AnnotateDecomposedMatmulPass()) |
175 | 176 | self.add_pass(QuantizeOperatorArguments()) |
176 | 177 | self.add_pass(ConvertELUParamsPass()) |
@@ -238,7 +239,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
238 | 239 | self.add_pass(DecomposeSinhPass()) |
239 | 240 | self.add_pass(DecomposeSignPass()) |
240 | 241 | self.add_pass(DecomposeDivTensorModePass()) |
241 | | - self.add_pass(ReplaceScalarWithTensorByProfilePass()) |
| 242 | + self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) |
242 | 243 | self.add_pass(DecomposeEmbeddingPass()) |
243 | 244 | self.add_pass(FuseQuantizedActivationPass()) |
244 | 245 | self.add_pass(RemoveGetItemPass()) |
@@ -328,7 +329,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): |
328 | 329 | self.add_pass(DecomposeSignPass()) |
329 | 330 | self.add_pass(DecomposeAddmmPass()) |
330 | 331 | self.add_pass(DecomposeDivTensorModePass()) |
331 | | - self.add_pass(ReplaceScalarWithTensorByProfilePass()) |
| 332 | + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) |
332 | 333 | self.add_pass(ScalarsToAttributePass()) |
333 | 334 | self.add_pass(DecomposeGroupNormPass()) |
334 | 335 | self.add_pass(DecomposeLayerNormPass()) |
|
0 commit comments