77
88# pyre-unsafe 
99from  executorch .backends .arm ._passes  import  (
10+     AddBiasPass ,
1011    AnnotateChannelsLastDimOrder ,
1112    AnnotateDecomposedMatmulPass ,
1213    BroadcastArgsPass ,
14+     CastBoolToInt8Pass ,
1315    CastInt64BuffersToInt32Pass ,
1416    CastToInt32Pass ,
1517    ComputeConstantOpsAOT ,
2325    ConvertSplitToSlicePass ,
2426    ConvertSqueezesToViewPass ,
2527    ConvertToClampPass ,
28+     DecomposeAvgPool2d ,
2629    DecomposeCosineSimilarityPass ,
2730    DecomposeDivPass ,
2831    DecomposeEmbeddingPass ,
2932    DecomposeGeluPass ,
33+     DecomposeGroupedConv ,
3034    DecomposeGroupNormPass ,
3135    DecomposeLayerNormPass ,
3236    DecomposeLeakyReLUPass ,
3539    DecomposeMaxPool2DPass ,
3640    DecomposeMeanDimPass ,
3741    DecomposeNotEqualPass ,
42+     DecomposeRoundPass ,
3843    DecomposeSelectPass ,
3944    DecomposeSiluPass ,
4045    DecomposeSoftmaxPass ,
6368    UnsqueezeBeforeRepeatPass ,
6469    UnsqueezeScalarPlaceholdersPass ,
6570)
66- 
6771from  executorch .backends .arm .tosa_specification  import  (
6872    TosaLoweringContext ,
6973    TosaSpecification ,
@@ -105,6 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
105109        if  self .tosa_spec .is_U55_subset :
106110            self .add_pass (CastToInt32Pass ())
107111
112+         self .add_pass (CastBoolToInt8Pass ())
108113        self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
109114        self .add_pass (AnnotateDecomposedMatmulPass ())
110115        self .add_pass (QuantizeOperatorArguments ())
@@ -115,8 +120,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115120        if  self .tosa_spec .is_U55_subset :
116121            self .add_pass (BroadcastArgsPass ())
117122        self .add_pass (DecomposeLinearPass ())
123+         self .add_pass (DecomposeAvgPool2d ())
118124        self .add_pass (ComputeConstantOpsAOT (exported_program ))
119125
126+         self .add_pass (DecomposeGroupedConv ())
120127        self .add_pass (RemoveClonePass ())
121128        self .add_pass (SizeAdjustConv2DPass ())
122129        self .add_pass (ConvertExpandCopyToRepeatPass ())
@@ -130,6 +137,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
130137
131138        self .add_pass (FuseViewCopyTransform ())
132139        self .add_pass (FuseConstantArgsPass (exported_program ))
140+         self .add_pass (AddBiasPass (exported_program ))
133141
134142        self .add_pass (InsertTableOpsPass (exported_program ))
135143        self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
@@ -139,8 +147,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
139147        return  self ._transform (exported_program .graph_module )
140148
141149    def  _tosa_080_MI_pipeline (self , exported_program : ExportedProgram ) ->  GraphModule :
150+         self .add_pass (DecomposeRoundPass ())
142151        self .add_pass (DecomposeSqrtPass ())
143152        self .add_pass (ConvertIntPowToMuls ())
153+         self .add_pass (CastBoolToInt8Pass ())
144154        self .add_pass (ReplaceScalarWithTensorArgPassTOSAMI ())
145155        self .add_pass (DecomposeEmbeddingPass ())
146156        self .add_pass (FuseQuantizedActivationPass ())
@@ -172,8 +182,10 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
172182        self .add_pass (RetraceFoldedDtypesPass ())
173183        self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
174184        self .add_pass (MatchArgRanksPass (exported_program ))
185+         self .add_pass (DecomposeAvgPool2d ())
175186        self .add_pass (ComputeConstantOpsAOT (exported_program ))
176187
188+         self .add_pass (DecomposeGroupedConv ())
177189        self .add_pass (RemoveClonePass ())
178190        self .add_pass (SizeAdjustConv2DPass ())
179191        self .add_pass (ConvertExpandCopyToRepeatPass ())
@@ -187,6 +199,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
187199
188200        self .add_pass (FuseViewCopyTransform ())
189201        self .add_pass (FuseConstantArgsPass (exported_program ))
202+         self .add_pass (AddBiasPass (exported_program ))
190203        self .add_pass (InsertTableOpsPass (exported_program ))
191204        self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
192205        self .add_pass (AnnotateChannelsLastDimOrder ())
@@ -219,6 +232,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
219232        self .add_pass (InsertCastForOpsWithInt64InputPass ())
220233        self .add_pass (DecomposeEmbeddingPass ())
221234        self .add_pass (DecomposeScaledDotProductAttention ())
235+         self .add_pass (DecomposeRoundPass ())
236+         self .add_pass (CastBoolToInt8Pass ())
222237        self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
223238        self .add_pass (ScalarsToAttributePass ())
224239        self .add_pass (DecomposeGroupNormPass ())
@@ -232,6 +247,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
232247        self .add_pass (DecomposeLinearVectorNormPass ())
233248        self .add_pass (DecomposeSqrtPass ())
234249        self .add_pass (DecomposeSiluPass ())
250+         self .add_pass (DecomposeAvgPool2d ())
235251
236252        if  self .tosa_spec .is_U55_subset :
237253            # Numerically stable softmax uses amax which is not supported on Ethos-U55 
0 commit comments