@@ -41,7 +41,7 @@ def forward(self, x, y):
4141 class BMMSingleInput (torch .nn .Module ):
4242 test_parameters = [
4343 (torch .rand (20 , 3 , 3 ),),
44- (torch .ones (2 , 128 , 128 ),),
44+ (torch .rand (2 , 128 , 128 ),),
4545 (10000 * torch .randn (4 , 25 , 25 ),),
4646 (5 + 5 * torch .randn (3 , 64 , 64 ),),
4747 ]
@@ -96,7 +96,7 @@ def _test_bmm_ethosu_BI_pipeline(
9696 compile_spec : CompileSpec ,
9797 test_data : Tuple [torch .Tensor , ...],
9898 ):
99- (
99+ tester = (
100100 ArmTester (
101101 module ,
102102 example_inputs = test_data ,
@@ -110,7 +110,10 @@ def _test_bmm_ethosu_BI_pipeline(
110110 .partition ()
111111 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
112112 .to_executorch ()
113+ .serialize ()
113114 )
115+ if common .is_option_enabled ("corstone300" ):
116+ tester .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
114117
115118 @parameterized .expand (BMM .test_parameters )
116119 def test_bmm_tosa_MI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
@@ -143,9 +146,20 @@ def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):
143146 self ._test_bmm_tosa_BI_pipeline (self .BMMSingleInput (), test_data )
144147
145148 @parameterized .expand (BMM .test_parameters )
149+ @unittest .expectedFailure
146150 def test_bmm_u55_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
147151 test_data = (operand1 , operand2 )
148- self ._test_bmm_tosa_BI_pipeline (self .BMM (), test_data )
152+ self ._test_bmm_ethosu_BI_pipeline (
153+ self .BMM (), common .get_u55_compile_spec (), test_data
154+ )
155+
156+ @parameterized .expand (BMM .test_parameters )
157+ @common .expectedFailureOnFVP
158+ def test_bmm_u85_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
159+ test_data = (operand1 , operand2 )
160+ self ._test_bmm_ethosu_BI_pipeline (
161+ self .BMM (), common .get_u85_compile_spec (), test_data
162+ )
149163
150164 # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
151165 @parameterized .expand (BMMSingleInput .test_parameters )
@@ -156,7 +170,9 @@ def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor):
156170 self .BMMSingleInput (), common .get_u55_compile_spec (), test_data
157171 )
158172
173+ # Numerical issues on FVP, MLETORCH 534
159174 @parameterized .expand (BMMSingleInput .test_parameters )
175+ @common .expectedFailureOnFVP
160176 def test_bmm_single_input_u85_BI (self , operand1 : torch .Tensor ):
161177 test_data = (operand1 ,)
162178 self ._test_bmm_ethosu_BI_pipeline (
0 commit comments