8787    QuantizeOperatorArguments ,
8888    RemoveNoopPass ,
8989    ReplaceInfValues ,
90-     ReplaceScalarWithTensorArgPassTOSABI ,
91-     ReplaceScalarWithTensorArgPassTOSAMI ,
90+     ReplaceScalarWithTensorByProfilePass ,
9291    RetraceFoldedDtypesPass ,
9392    RewriteConv2dPass ,
9493    RewriteMatmulPass ,
@@ -154,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
154153        with  TosaLoweringContext (self .tosa_spec ):
155154            return  self (graph_module ).graph_module 
156155
157-     def  _tosa_INT_pipeline (self , exported_program : ExportedProgram ) ->  GraphModule :
156+     def  _tosa_INT_pipeline (
157+         self , exported_program : ExportedProgram , graph_module : GraphModule 
158+     ) ->  GraphModule :
158159        self .add_pass (AnnotateOutputDimOrderPass ())
159160        self .add_pass (FuseQuantizedActivationPass ())
160161        self .add_pass (RemoveGetItemPass ())
161162        self .add_pass (ConvertSplitToSlicePass ())
162163        self .add_pass (ConvertMmToBmmPass ())
163-         self .add_pass (
164-             DecomposeMeanDimPass (exported_program .graph_module , self .tosa_spec )
165-         )
164+         self .add_pass (DecomposeMeanDimPass (graph_module , self .tosa_spec ))
166165        self .add_pass (ConvertFullLikeToFullPass ())
167166        self .add_pass (ConvertToClampPass ())
168167        self .add_pass (ConvertMinMaxPass ())
@@ -172,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
172171            self .add_pass (CastToInt32Pass ())
173172
174173        self .add_pass (CastBoolToInt8Pass ())
175-         self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
174+         self .add_pass (ReplaceScalarWithTensorByProfilePass ())
176175        self .add_pass (AnnotateDecomposedMatmulPass ())
177176        self .add_pass (QuantizeOperatorArguments ())
178177        self .add_pass (ConvertELUParamsPass ())
@@ -219,9 +218,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
219218        self .add_pass (InsertRescalePass ())
220219
221220        self .validate_constraints_mandatory ()
222-         return  self ._transform (exported_program . graph_module )
221+         return  self ._transform (graph_module )
223222
224-     def  _tosa_FP_pipeline (self , exported_program : ExportedProgram ) ->  GraphModule :
223+     def  _tosa_FP_pipeline (
224+         self , exported_program : ExportedProgram , graph_module : GraphModule 
225+     ) ->  GraphModule :
225226        self .add_pass (AnnotateOutputDimOrderPass ())
226227        self .add_pass (DecomposeExpm1Pass ())
227228        self .add_pass (DecomposeLogitPass ())
@@ -242,7 +243,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
242243        self .add_pass (DecomposeSinhPass ())
243244        self .add_pass (DecomposeSignPass ())
244245        self .add_pass (DecomposeDivTensorModePass ())
245-         self .add_pass (ReplaceScalarWithTensorArgPassTOSAMI ())
246+         self .add_pass (ReplaceScalarWithTensorByProfilePass ())
246247        self .add_pass (DecomposeEmbeddingPass ())
247248        self .add_pass (FuseQuantizedActivationPass ())
248249        self .add_pass (RemoveGetItemPass ())
@@ -256,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
256257        self .add_pass (DecomposeLayerNormPass ())
257258        self .add_pass (DecomposeBatchNormNoStatsPass ())
258259        self .add_pass (DecomposeVarPass ())
259-         self .add_pass (
260-             DecomposeMeanDimPass (exported_program .graph_module , self .tosa_spec )
261-         )
260+         self .add_pass (DecomposeMeanDimPass (graph_module , self .tosa_spec ))
262261        self .add_pass (DecomposeNotEqualPass ())
263262        self .add_pass (DecomposeDivPass ())
264263        self .add_pass (DecomposeAddSubAlphaPass ())
@@ -306,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
306305        self .add_pass (InsertRescalePass ())
307306
308307        self .validate_constraints_mandatory ()
309-         return  self ._transform (exported_program . graph_module )
308+         return  self ._transform (graph_module )
310309
311-     def  transform_to_backend_pipeline (self , exported_program : ExportedProgram ):
310+     def  transform_to_backend_pipeline (
311+         self , exported_program : ExportedProgram , graph_module : GraphModule 
312+     ):
312313        """Apply passes before transforming program to backend""" 
313314        if  self .tosa_spec  ==  TosaSpecification .create_from_string ("TOSA-1.0+FP" ):
314-             return  self ._tosa_FP_pipeline (exported_program )
315+             return  self ._tosa_FP_pipeline (exported_program ,  graph_module )
315316        elif  self .tosa_spec  ==  TosaSpecification .create_from_string ("TOSA-1.0+INT" ):
316-             return  self ._tosa_INT_pipeline (exported_program )
317+             return  self ._tosa_INT_pipeline (exported_program ,  graph_module )
317318        else :
318319            raise  NotImplementedError (
319320                f"No pass pipeline implemented for { self .tosa_spec = }  " 
@@ -335,7 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
335336        self .add_pass (DecomposeAddmmPass ())
336337        self .add_pass (DecomposeDivTensorModePass ())
337338        self .add_pass (DecomposeAddSubAlphaPass ())
338-         self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
339+         self .add_pass (ReplaceScalarWithTensorByProfilePass ())
339340        self .add_pass (ScalarsToAttributePass ())
340341        self .add_pass (DecomposeGroupNormPass ())
341342        self .add_pass (DecomposeLayerNormPass ())
0 commit comments