1616 ConvertConv1dToConv2d ,
1717 ConvertUpsampleBicubicWithBilinear ,
1818 DecomposeAny ,
19+ DecomposeCDist ,
1920 DecomposeEinsum ,
2021 DecomposeExpM1 ,
2122 DecomposeLinalgVectorNorm ,
3233 RecomposePixelUnshuffle ,
3334 RecomposeRmsNorm ,
3435 ReduceDynamicRange ,
36+ Remove0DTensor ,
3537 RemoveRedundancy ,
3638 ReplaceArangeArgs ,
3739 ReplaceIndexPutInput ,
@@ -71,7 +73,7 @@ def get_capture_program_passes():
7173 # If a pass is activated, it will be executed by default.
7274 default_passes_and_setting = [
7375 (AnnotateQuantAttrs , True ),
74- (AnnotateStack , False ),
76+ (AnnotateStack , True ),
7577 (AnnotateUnbind , True ),
7678 (ConvertBmmToMatmul , True ),
7779 (ConvertConv1dToConv2d , True ),
@@ -84,6 +86,7 @@ def get_capture_program_passes():
8486 (LayoutTransform , True ),
8587 (RecomposePixelUnshuffle , True ),
8688 (RecomposeRmsNorm , False ),
89+ (Remove0DTensor , True ),
8790 (RemoveRedundancy , True ),
8891 (ReplaceIndexPutInput , True ),
8992 (TagQuantIO , False ),
@@ -176,7 +179,23 @@ def transform_for_to_edge_pipeline(
176179
177180 return exported_program
178181
182+ # Before quantizer
183+ def transform_for_annotation_pipeline (self , graph_module : GraphModule ):
184+ self .add_pass (ReduceDynamicRange ())
185+ self .add_pass (RecomposePixelUnshuffle (quantization_capture = True ))
186+ self .add_pass (ReplaceArangeArgs ())
187+ self .add_pass (DecomposeCDist ())
188+ self .add_pass (DecomposeScaledDotProductAttention ())
189+ self .add_pass (DecomposeSilu ())
190+ self .add_pass (DecomposeEinsum ())
191+ self .add_pass (DecomposeExpM1 ())
192+ self .add_pass (DecomposeLinalgVectorNorm (quantization_capture = True ))
193+ self .add_pass (ReplaceInfValues ())
194+ self .add_pass (LiftConstantScalarOperands ())
195+ return self ._transform (graph_module )
196+
179197 def transform_for_export_pipeline (self , exported_program : ExportedProgram ):
198+ self .add_pass (DecomposeCDist ())
180199 self .add_pass (DecomposeScaledDotProductAttention ())
181200 self .add_pass (DecomposeLinalgVectorNorm (quantization_capture = True ))
182201 self .add_pass (DecomposeExpM1 ())
@@ -191,16 +210,3 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
191210 self .add_pass (LayoutTransform (exported_program , insert_permute = True ))
192211 self .add_pass (FuseConsecutiveTranspose ())
193212 return self ._transform (exported_program .graph_module )
194-
195- def transform_for_annotation_pipeline (self , graph_module : GraphModule ):
196- self .add_pass (ReduceDynamicRange ())
197- self .add_pass (RecomposePixelUnshuffle (quantization_capture = True ))
198- self .add_pass (ReplaceArangeArgs ())
199- self .add_pass (DecomposeScaledDotProductAttention ())
200- self .add_pass (DecomposeSilu ())
201- self .add_pass (DecomposeEinsum ())
202- self .add_pass (DecomposeExpM1 ())
203- self .add_pass (DecomposeLinalgVectorNorm (quantization_capture = True ))
204- self .add_pass (ReplaceInfValues ())
205- self .add_pass (LiftConstantScalarOperands ())
206- return self ._transform (graph_module )
0 commit comments