1010from executorch .backends .arm ._passes import (
1111 AnnotateChannelsLastDimOrder ,
1212 AnnotateDecomposedMatmulPass ,
13+ BroadcastArgsPass ,
1314 CastInt64BuffersToInt32Pass ,
1415 CastToInt32Pass ,
1516 ComputeConstantOpsAOT ,
2930 DecomposeLayerNormPass ,
3031 DecomposeLeakyReLUPass ,
3132 DecomposeLinearPass ,
33+ DecomposeLinearVectorNormPass ,
3234 DecomposeMeanDimPass ,
3335 DecomposeNotEqualPass ,
3436 DecomposeSelectPass ,
@@ -86,6 +88,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8688 self .add_pass (ConvertSplitToSlicePass ())
8789 self .add_pass (ConvertMmToBmmPass ())
8890 self .add_pass (DecomposeLinearPass ())
91+ self .add_pass (DecomposeLinearVectorNormPass ())
8992 self .add_pass (DecomposeMeanDimPass ())
9093 self .add_pass (ConvertFullLikeToFullPass ())
9194 self .add_pass (ConvertToClampPass ())
@@ -102,6 +105,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
102105 self .add_pass (RetraceFoldedDtypesPass ())
103106 self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
104107 self .add_pass (MatchArgRanksPass (exported_program ))
108+ if self .tosa_spec .is_U55_subset :
109+ self .add_pass (BroadcastArgsPass ())
105110 self .add_pass (ComputeConstantOpsAOT (exported_program ))
106111
107112 self .add_pass (RemoveClonePass ())
@@ -133,6 +138,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
133138 self .add_pass (FuseBatchnorm2DPass (exported_program ))
134139 self .add_pass (ConvertMmToBmmPass ())
135140 self .add_pass (DecomposeLinearPass ())
141+ self .add_pass (DecomposeLinearVectorNormPass ())
136142 self .add_pass (DecomposeLeakyReLUPass ())
137143 self .add_pass (DecomposeBatchNormPass ())
138144 self .add_pass (DecomposeLayerNormPass ())
@@ -207,6 +213,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
207213 self .add_pass (DecomposeCosineSimilarityPass ())
208214 self .add_pass (DecomposeDivPass ())
209215 self .add_pass (DecomposeLeakyReLUPass ())
216+ self .add_pass (DecomposeLinearVectorNormPass ())
210217 self .add_pass (DecomposeSqrtPass ())
211218 self .add_pass (DecomposeSiluPass ())
212219
0 commit comments