1313from executorch .backends .arm .test import common
1414from executorch .backends .arm .test .tester .arm_tester import ArmTester
1515from executorch .exir import EdgeCompileConfig
16+ from executorch .exir .backend .compile_spec_schema import CompileSpec
1617from parameterized import parameterized
1718
1819
@@ -92,16 +93,17 @@ def _test_add_tosa_BI_pipeline(
9293 .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
9394 )
9495
95- def _test_add_u55_BI_pipeline (
96+ def _test_add_ethos_BI_pipeline (
9697 self ,
9798 module : torch .nn .Module ,
99+ compile_spec : CompileSpec ,
98100 test_data : Tuple [torch .Tensor ],
99101 ):
100102 tester = (
101103 ArmTester (
102104 module ,
103105 example_inputs = test_data ,
104- compile_spec = common . get_u55_compile_spec ( permute_memory_to_nhwc = True ) ,
106+ compile_spec = compile_spec ,
105107 )
106108 .quantize ()
107109 .export ()
@@ -114,8 +116,7 @@ def _test_add_u55_BI_pipeline(
114116 .serialize ()
115117 )
116118
117- if common .is_option_enabled ("corstone300" ):
118- tester .run_method_and_compare_outputs (qtol = 1 , inputs = test_data )
119+ return tester
119120
120121 @parameterized .expand (Add .test_parameters )
121122 def test_add_tosa_MI (self , test_data : torch .Tensor ):
@@ -130,7 +131,22 @@ def test_add_tosa_BI(self, test_data: torch.Tensor):
130131 @parameterized .expand (Add .test_parameters )
131132 def test_add_u55_BI (self , test_data : torch .Tensor ):
132133 test_data = (test_data ,)
133- self ._test_add_u55_BI_pipeline (self .Add (), test_data )
134+ tester = self ._test_add_ethos_BI_pipeline (
135+ self .Add (),
136+ common .get_u55_compile_spec (permute_memory_to_nhwc = True ),
137+ test_data ,
138+ )
139+ if common .is_option_enabled ("corstone300" ):
140+ tester .run_method_and_compare_outputs (qtol = 1 , inputs = test_data )
141+
142+ @parameterized .expand (Add .test_parameters )
143+ def test_add_u85_BI (self , test_data : torch .Tensor ):
144+ test_data = (test_data ,)
145+ self ._test_add_ethos_BI_pipeline (
146+ self .Add (),
147+ common .get_u85_compile_spec (permute_memory_to_nhwc = True ),
148+ test_data ,
149+ )
134150
135151 @parameterized .expand (Add2 .test_parameters )
136152 def test_add2_tosa_MI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
@@ -145,4 +161,15 @@ def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
145161 @parameterized .expand (Add2 .test_parameters )
146162 def test_add2_u55_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
147163 test_data = (operand1 , operand2 )
148- self ._test_add_u55_BI_pipeline (self .Add2 (), test_data )
164+ tester = self ._test_add_ethos_BI_pipeline (
165+ self .Add2 (), common .get_u55_compile_spec (), test_data
166+ )
167+ if common .is_option_enabled ("corstone300" ):
168+ tester .run_method_and_compare_outputs (qtol = 1 , inputs = test_data )
169+
170+ @parameterized .expand (Add2 .test_parameters )
171+ def test_add2_u85_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
172+ test_data = (operand1 , operand2 )
173+ self ._test_add_ethos_BI_pipeline (
174+ self .Add2 (), common .get_u85_compile_spec (), test_data
175+ )
0 commit comments