77
88# pyre-unsafe
99from executorch .backends .arm ._passes import (
10+ AddBiasPass ,
1011 AnnotateChannelsLastDimOrder ,
1112 AnnotateDecomposedMatmulPass ,
1213 BroadcastArgsPass ,
14+ CastBoolToInt8Pass ,
1315 CastInt64BuffersToInt32Pass ,
1416 CastToInt32Pass ,
1517 ComputeConstantOpsAOT ,
2325 ConvertSplitToSlicePass ,
2426 ConvertSqueezesToViewPass ,
2527 ConvertToClampPass ,
28+ DecomposeAvgPool2d ,
2629 DecomposeCosineSimilarityPass ,
2730 DecomposeDivPass ,
2831 DecomposeEmbeddingPass ,
2932 DecomposeGeluPass ,
33+ DecomposeGroupedConv ,
3034 DecomposeGroupNormPass ,
3135 DecomposeLayerNormPass ,
3236 DecomposeLeakyReLUPass ,
3539 DecomposeMaxPool2DPass ,
3640 DecomposeMeanDimPass ,
3741 DecomposeNotEqualPass ,
42+ DecomposeRoundPass ,
3843 DecomposeSelectPass ,
3944 DecomposeSiluPass ,
4045 DecomposeSoftmaxPass ,
6368 UnsqueezeBeforeRepeatPass ,
6469 UnsqueezeScalarPlaceholdersPass ,
6570)
66-
6771from executorch .backends .arm .tosa_specification import (
6872 TosaLoweringContext ,
6973 TosaSpecification ,
@@ -105,6 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
105109 if self .tosa_spec .is_U55_subset :
106110 self .add_pass (CastToInt32Pass ())
107111
112+ self .add_pass (CastBoolToInt8Pass ())
108113 self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
109114 self .add_pass (AnnotateDecomposedMatmulPass ())
110115 self .add_pass (QuantizeOperatorArguments ())
@@ -115,8 +120,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115120 if self .tosa_spec .is_U55_subset :
116121 self .add_pass (BroadcastArgsPass ())
117122 self .add_pass (DecomposeLinearPass ())
123+ self .add_pass (DecomposeAvgPool2d ())
118124 self .add_pass (ComputeConstantOpsAOT (exported_program ))
119125
126+ self .add_pass (DecomposeGroupedConv ())
120127 self .add_pass (RemoveClonePass ())
121128 self .add_pass (SizeAdjustConv2DPass ())
122129 self .add_pass (ConvertExpandCopyToRepeatPass ())
@@ -130,6 +137,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
130137
131138 self .add_pass (FuseViewCopyTransform ())
132139 self .add_pass (FuseConstantArgsPass (exported_program ))
140+ self .add_pass (AddBiasPass (exported_program ))
133141
134142 self .add_pass (InsertTableOpsPass (exported_program ))
135143 self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
@@ -139,8 +147,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
139147 return self ._transform (exported_program .graph_module )
140148
141149 def _tosa_080_MI_pipeline (self , exported_program : ExportedProgram ) -> GraphModule :
150+ self .add_pass (DecomposeRoundPass ())
142151 self .add_pass (DecomposeSqrtPass ())
143152 self .add_pass (ConvertIntPowToMuls ())
153+ self .add_pass (CastBoolToInt8Pass ())
144154 self .add_pass (ReplaceScalarWithTensorArgPassTOSAMI ())
145155 self .add_pass (DecomposeEmbeddingPass ())
146156 self .add_pass (FuseQuantizedActivationPass ())
@@ -172,8 +182,10 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
172182 self .add_pass (RetraceFoldedDtypesPass ())
173183 self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
174184 self .add_pass (MatchArgRanksPass (exported_program ))
185+ self .add_pass (DecomposeAvgPool2d ())
175186 self .add_pass (ComputeConstantOpsAOT (exported_program ))
176187
188+ self .add_pass (DecomposeGroupedConv ())
177189 self .add_pass (RemoveClonePass ())
178190 self .add_pass (SizeAdjustConv2DPass ())
179191 self .add_pass (ConvertExpandCopyToRepeatPass ())
@@ -187,6 +199,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
187199
188200 self .add_pass (FuseViewCopyTransform ())
189201 self .add_pass (FuseConstantArgsPass (exported_program ))
202+ self .add_pass (AddBiasPass (exported_program ))
190203 self .add_pass (InsertTableOpsPass (exported_program ))
191204 self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
192205 self .add_pass (AnnotateChannelsLastDimOrder ())
@@ -219,6 +232,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
219232 self .add_pass (InsertCastForOpsWithInt64InputPass ())
220233 self .add_pass (DecomposeEmbeddingPass ())
221234 self .add_pass (DecomposeScaledDotProductAttention ())
235+ self .add_pass (DecomposeRoundPass ())
236+ self .add_pass (CastBoolToInt8Pass ())
222237 self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
223238 self .add_pass (ScalarsToAttributePass ())
224239 self .add_pass (DecomposeGroupNormPass ())
@@ -232,6 +247,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
232247 self .add_pass (DecomposeLinearVectorNormPass ())
233248 self .add_pass (DecomposeSqrtPass ())
234249 self .add_pass (DecomposeSiluPass ())
250+ self .add_pass (DecomposeAvgPool2d ())
235251
236252 if self .tosa_spec .is_U55_subset :
237253 # Numerically stable softmax uses amax which is not supported on Ethos-U55
0 commit comments