1010from executorch .backends .arm ._passes import (
1111 AnnotateChannelsLastDimOrder ,
1212 AnnotateDecomposedMatmulPass ,
13- CastInt64ToInt32Pass ,
13+ CastInt64BuffersToInt32Pass ,
14+ CastToInt32Pass ,
1415 ComputeConstantOpsAOT ,
1516 Conv1dUnsqueezePass ,
1617 ConvertAnyDefaultDimDimsPass ,
@@ -80,6 +81,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8081 self .add_pass (ConvertToClampPass ())
8182 self .add_pass (ConvertMinMaxPass ())
8283 self .add_pass (ConvertAnyDefaultDimDimsPass ())
84+ if isinstance (self .tosa_spec , Tosa_0_80 ) and self .tosa_spec .is_U55_subset :
85+ self .add_pass (CastToInt32Pass ())
8386
8487 self .add_pass (ReplaceScalarWithTensorArgPass ())
8588 self .add_pass (AnnotateDecomposedMatmulPass ())
@@ -94,7 +97,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9497 self .add_pass (SizeAdjustConv2DPass ())
9598 self .add_pass (ConvertExpandCopyToRepeatPass ())
9699 self .add_pass (UnsqueezeBeforeRepeatPass ())
97- self .add_pass (CastInt64ToInt32Pass (exported_program ))
100+ self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
98101 self .add_pass (KeepDimsFalseToSqueezePass ())
99102 self .add_pass (Conv1dUnsqueezePass (exported_program ))
100103 self .add_pass (DecomposeSelectPass ())
@@ -141,7 +144,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
141144 self .add_pass (SizeAdjustConv2DPass ())
142145 self .add_pass (ConvertExpandCopyToRepeatPass ())
143146 self .add_pass (UnsqueezeBeforeRepeatPass ())
144- self .add_pass (CastInt64ToInt32Pass (exported_program ))
147+ self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
145148 self .add_pass (KeepDimsFalseToSqueezePass ())
146149 self .add_pass (Conv1dUnsqueezePass (exported_program ))
147150 self .add_pass (DecomposeSelectPass ())
0 commit comments