44# LICENSE file in the root directory of this source tree.
55
66
7- import pytest
87import torch
98from executorch .backends .arm .test .common import parametrize
109from executorch .backends .cortex_m .test .tester import (
@@ -60,6 +59,16 @@ class CortexMTensorMul(Model):
6059 }
6160
6261
62+ class CortexMTensorMulBroadCast (Model ):
63+ ops_before_transforms = {
64+ "executorch_exir_dialects_edge__ops_aten_mul_Tensor" : 1 ,
65+ }
66+
67+ ops_after_transforms = {
68+ "executorch_exir_dialects_edge__ops_aten_mul_Tensor" : 1 ,
69+ }
70+
71+
6372test_cases = {
6473 "self_scalar" : McuTestCase (
6574 CortexMSelfMul (),
@@ -91,22 +100,22 @@ class CortexMTensorMul(Model):
91100 ),
92101 "tensor_scalar" : McuTestCase (
93102 CortexMScalarMul (),
94- (torch .ones (2 , 2 ), 1.0 ),
103+ (torch .ones (1 ), 1.0 ),
95104 ),
96105 "scalar_tensor" : McuTestCase (
97106 CortexMScalarMul (),
98- (1000.0 , torch .ones (2 , 2 )),
107+ (1000.0 , torch .ones (1 )),
99108 ),
100109 "broadcast_1" : McuTestCase (
101- CortexMTensorMul (),
110+ CortexMTensorMulBroadCast (),
102111 (torch .ones (1 ), torch .ones (2 , 2 , 2 , 2 )),
103112 ),
104113 "broadcast_2" : McuTestCase (
105- CortexMTensorMul (),
114+ CortexMTensorMulBroadCast (),
106115 (torch .ones ((2 , 1 , 1 , 1 )), torch .ones (1 )),
107116 ),
108117 "broadcast_3" : McuTestCase (
109- CortexMTensorMul (),
118+ CortexMTensorMulBroadCast (),
110119 (
111120 ramp_tensor (- 2 , 2 , (2 , 1 , 2 , 1 )),
112121 ramp_tensor (- 5 , 5 , (1 , 2 , 1 , 2 )),
@@ -115,17 +124,23 @@ class CortexMTensorMul(Model):
115124}
116125
117126
118- @pytest .mark .skip (reason = "Not implemented yet" )
119- @parametrize ("test_case" , test_cases )
127+ xfail_cases = {
128+ "self_scalar" : "lift_constant_tensor_pass assumes fake tensors for scalars" ,
129+ "scalar_scalar" : "lift_constant_tensor_pass assumes fake tensors for scalars" ,
130+ }
131+
132+
133+ @parametrize ("test_case" , test_cases , xfails = xfail_cases )
120134def test_dialect_mul (test_case ):
121135 tester = CortexMTester (test_case .model , test_case .example_inputs )
122136 tester .test_dialect (
123- test_case .model .ops_before_transforms , test_case .model .ops_after_transforms
137+ test_case .model .ops_before_transforms ,
138+ test_case .model .ops_after_transforms ,
139+ qtol = 1 ,
124140 )
125141
126142
127- @pytest .mark .skip (reason = "Not implemented yet" )
128- @parametrize ("test_case" , test_cases )
143+ @parametrize ("test_case" , test_cases , xfails = xfail_cases )
129144def test_implementation_mul (test_case ):
130145 tester = CortexMTester (test_case .model , test_case .example_inputs )
131- tester .test_implementation ()
146+ tester .test_implementation (qtol = 1 )
0 commit comments