5252from  executorch .backends .arm ._passes .fuse_quantized_activation_pass  import  (  # type: ignore[import-not-found] 
5353    FuseQuantizedActivationPass ,
5454)
55+ from  executorch .backends .arm ._passes .insert_rescales_pass  import  InsertRescalePass 
5556from  executorch .backends .arm ._passes .insert_table_ops  import  InsertTableOpsPass 
5657from  executorch .backends .arm ._passes .keep_dims_false_to_squeeze_pass  import  (
5758    KeepDimsFalseToSqueezePass ,
7576    UnsqueezeScalarPlaceholdersPass ,
7677)
7778from  executorch .backends .arm .tosa_specification  import  TosaSpecification 
79+ 
80+ from  executorch .backends .transforms .replace_scalar_with_tensor  import  (
81+     ReplaceScalarWithTensorArgPass ,
82+ )
7883from  executorch .backends .xnnpack ._passes .remove_getitem_op  import  RemoveGetItemPass 
7984from  executorch .exir  import  ExportedProgram 
8085from  executorch .exir .pass_manager  import  PassManager 
@@ -100,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
100105        self .add_pass (ConvertMeanDimToAveragePoolPass ())
101106        self .add_pass (ConvertFullLikeToFullPass ())
102107
108+         self .add_pass (ReplaceScalarWithTensorArgPass ())
103109        self .add_pass (AnnotateDecomposedMatmulPass ())
104110        self .add_pass (QuantizeOperatorArguments ())
105111        self .add_pass (FoldAndAnnotateQParamsPass ())  # type: ignore[call-arg] 
@@ -119,11 +125,11 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
119125        self .add_pass (ConvertSqueezesToViewPass ())
120126
121127        self .add_pass (AnnotateChannelsLastDimOrder ())
122- 
128+          self . add_pass ( InsertRescalePass ()) 
123129        return  self ._transform (exported_program .graph_module )
124130
125131    def  _tosa_080_MI_pipeline (self , exported_program : ExportedProgram ) ->  GraphModule :
126- 
132+          self . add_pass ( ReplaceScalarWithTensorArgPass ()) 
127133        self .add_pass (FuseQuantizedActivationPass ())
128134        self .add_pass (RemoveGetItemPass ())
129135        self .add_pass (ConvertSplitToSlicePass ())
@@ -157,6 +163,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
157163        self .add_pass (ConvertSqueezesToViewPass ())
158164
159165        self .add_pass (AnnotateChannelsLastDimOrder ())
166+         self .add_pass (InsertRescalePass ())
160167
161168        return  self ._transform (exported_program .graph_module )
162169
@@ -173,6 +180,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
173180
174181    def  transform_for_annotation_pipeline (self , graph_module : GraphModule ):
175182        self .add_pass (ScalarsToAttributePass ())
183+         self .add_pass (ReplaceScalarWithTensorArgPass ())
176184        self .add_pass (DecomposeLayerNormPass ())
177185        self .add_pass (DecomposeVarPass ())
178186        self .add_pass (DecomposeMeanDimPass ())
0 commit comments