|
87 | 87 | QuantizeOperatorArguments, |
88 | 88 | RemoveNoopPass, |
89 | 89 | ReplaceInfValues, |
90 | | - ReplaceScalarWithTensorByProfilePass, |
| 90 | + ReplaceScalarWithTensorArgPassTOSABI, |
| 91 | + ReplaceScalarWithTensorArgPassTOSAMI, |
91 | 92 | RetraceFoldedDtypesPass, |
92 | 93 | RewriteConv2dPass, |
93 | 94 | RewriteMatmulPass, |
@@ -171,7 +172,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
171 | 172 | self.add_pass(CastToInt32Pass()) |
172 | 173 |
|
173 | 174 | self.add_pass(CastBoolToInt8Pass()) |
174 | | - self.add_pass(ReplaceScalarWithTensorByProfilePass()) |
| 175 | + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) |
175 | 176 | self.add_pass(AnnotateDecomposedMatmulPass()) |
176 | 177 | self.add_pass(QuantizeOperatorArguments()) |
177 | 178 | self.add_pass(ConvertELUParamsPass()) |
@@ -241,7 +242,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
241 | 242 | self.add_pass(DecomposeSinhPass()) |
242 | 243 | self.add_pass(DecomposeSignPass()) |
243 | 244 | self.add_pass(DecomposeDivTensorModePass()) |
244 | | - self.add_pass(ReplaceScalarWithTensorByProfilePass()) |
| 245 | + self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) |
245 | 246 | self.add_pass(DecomposeEmbeddingPass()) |
246 | 247 | self.add_pass(FuseQuantizedActivationPass()) |
247 | 248 | self.add_pass(RemoveGetItemPass()) |
@@ -334,7 +335,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): |
334 | 335 | self.add_pass(DecomposeAddmmPass()) |
335 | 336 | self.add_pass(DecomposeDivTensorModePass()) |
336 | 337 | self.add_pass(DecomposeAddSubAlphaPass()) |
337 | | - self.add_pass(ReplaceScalarWithTensorByProfilePass()) |
| 338 | + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) |
338 | 339 | self.add_pass(ScalarsToAttributePass()) |
339 | 340 | self.add_pass(DecomposeGroupNormPass()) |
340 | 341 | self.add_pass(DecomposeLayerNormPass()) |
|
0 commit comments