1515 ConvertBmmToMatmul ,
1616 ConvertConv1dToConv2d ,
1717 DecomposeAny ,
18+ DecomposeCDist ,
1819 DecomposeEinsum ,
1920 DecomposeExpM1 ,
2021 DecomposeLinalgVectorNorm ,
3132 RecomposePixelUnshuffle ,
3233 RecomposeRmsNorm ,
3334 ReduceDynamicRange ,
35+ Remove0DTensor ,
3436 RemoveRedundancy ,
3537 ReplaceArangeArgs ,
3638 ReplaceIndexPutInput ,
@@ -70,7 +72,7 @@ def get_capture_program_passes():
7072 # If a pass is activated, it will be executed by default.
7173 default_passes_and_setting = [
7274 (AnnotateQuantAttrs , True ),
73- (AnnotateStack , False ),
75+ (AnnotateStack , True ),
7476 (AnnotateUnbind , True ),
7577 (ConvertBmmToMatmul , True ),
7678 (ConvertConv1dToConv2d , True ),
@@ -82,6 +84,7 @@ def get_capture_program_passes():
8284 (LayoutTransform , True ),
8385 (RecomposePixelUnshuffle , True ),
8486 (RecomposeRmsNorm , False ),
87+ (Remove0DTensor , True ),
8588 (RemoveRedundancy , True ),
8689 (ReplaceIndexPutInput , True ),
8790 (TagQuantIO , False ),
@@ -174,7 +177,23 @@ def transform_for_to_edge_pipeline(
174177
175178 return exported_program
176179
180+ # Before quantizer
181+ def transform_for_annotation_pipeline (self , graph_module : GraphModule ):
182+ self .add_pass (ReduceDynamicRange ())
183+ self .add_pass (RecomposePixelUnshuffle (quantization_capture = True ))
184+ self .add_pass (ReplaceArangeArgs ())
185+ self .add_pass (DecomposeCDist ())
186+ self .add_pass (DecomposeScaledDotProductAttention ())
187+ self .add_pass (DecomposeSilu ())
188+ self .add_pass (DecomposeEinsum ())
189+ self .add_pass (DecomposeExpM1 ())
190+ self .add_pass (DecomposeLinalgVectorNorm (quantization_capture = True ))
191+ self .add_pass (ReplaceInfValues ())
192+ self .add_pass (LiftConstantScalarOperands ())
193+ return self ._transform (graph_module )
194+
177195 def transform_for_export_pipeline (self , exported_program : ExportedProgram ):
196+ self .add_pass (DecomposeCDist ())
178197 self .add_pass (DecomposeScaledDotProductAttention ())
179198 self .add_pass (DecomposeLinalgVectorNorm (quantization_capture = True ))
180199 self .add_pass (DecomposeExpM1 ())
@@ -189,16 +208,3 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
189208 self .add_pass (LayoutTransform (exported_program , insert_permute = True ))
190209 self .add_pass (FuseConsecutiveTranspose ())
191210 return self ._transform (exported_program .graph_module )
192-
193- def transform_for_annotation_pipeline (self , graph_module : GraphModule ):
194- self .add_pass (ReduceDynamicRange ())
195- self .add_pass (RecomposePixelUnshuffle (quantization_capture = True ))
196- self .add_pass (ReplaceArangeArgs ())
197- self .add_pass (DecomposeScaledDotProductAttention ())
198- self .add_pass (DecomposeSilu ())
199- self .add_pass (DecomposeEinsum ())
200- self .add_pass (DecomposeExpM1 ())
201- self .add_pass (DecomposeLinalgVectorNorm (quantization_capture = True ))
202- self .add_pass (ReplaceInfValues ())
203- self .add_pass (LiftConstantScalarOperands ())
204- return self ._transform (graph_module )
0 commit comments