@@ -37,15 +37,17 @@ class Expand(torch.nn.Module):
3737 test_parameters = [
3838 (torch .rand (1 ), (2 ,)),
3939 (torch .randn (1 , 4 ), (1 , - 1 )),
40- (torch .rand (1 , 1 , 2 , 2 ), (4 , 3 , - 1 , 2 )),
4140 (torch .randn (1 ), (2 , 2 , 4 )),
42- (torch .rand ( 3 , 2 , 4 , 1 ), (- 1 , - 1 , - 1 , 3 )),
41+ (torch .randn ( 1 , 1 , 1 , 5 ), (1 , 4 , - 1 , - 1 )),
4342 (torch .randn (1 , 1 , 192 ), (1 , - 1 , - 1 )),
43+ (torch .randn (1 , 1 ), (1 , 2 , 2 , 4 )),
44+ (torch .randn (1 , 1 ), (2 , 2 , 2 , 4 )),
4445 (torch .randn (10 , 1 , 1 , 97 ), (- 1 , 4 , - 1 , - 1 )),
46+ (torch .rand (1 , 1 , 2 , 2 ), (4 , 3 , - 1 , 2 )),
4547 ]
4648
47- def forward (self , x : torch .Tensor , multiples : Sequence ):
48- return x .expand (multiples )
49+ def forward (self , x : torch .Tensor , m : Sequence ):
50+ return x .expand (m )
4951
5052 def _test_expand_tosa_MI_pipeline (self , module : torch .nn .Module , test_data : Tuple ):
5153 (
@@ -113,20 +115,34 @@ def test_expand_tosa_MI(self, test_input, multiples):
113115 def test_expand_tosa_BI (self , test_input , multiples ):
114116 self ._test_expand_tosa_BI_pipeline (self .Expand (), (test_input , multiples ))
115117
116- # Mismatch in provided number of inputs and model signature, MLETORCH 519
117- @parameterized .expand (Expand .test_parameters )
118+ @parameterized .expand (Expand .test_parameters [:- 3 ])
118119 @pytest .mark .corstone_fvp
119- @conftest .expectedFailureOnFVP
120120 def test_expand_u55_BI (self , test_input , multiples ):
121121 self ._test_expand_ethosu_BI_pipeline (
122122 common .get_u55_compile_spec (), self .Expand (), (test_input , multiples )
123123 )
124124
125- # Mismatch in provided number of inputs and model signature, MLETORCH 519
126- @parameterized .expand (Expand .test_parameters )
125+ # MLETORCH-629: Expand does not work on FVP with batch>1
126+ @parameterized .expand (Expand .test_parameters [ - 3 :] )
127127 @pytest .mark .corstone_fvp
128128 @conftest .expectedFailureOnFVP
129+ def test_expand_u55_BI_xfails (self , test_input , multiples ):
130+ self ._test_expand_ethosu_BI_pipeline (
131+ common .get_u55_compile_spec (), self .Expand (), (test_input , multiples )
132+ )
133+
134+ @parameterized .expand (Expand .test_parameters [:- 3 ])
135+ @pytest .mark .corstone_fvp
129136 def test_expand_u85_BI (self , test_input , multiples ):
130137 self ._test_expand_ethosu_BI_pipeline (
131138 common .get_u85_compile_spec (), self .Expand (), (test_input , multiples )
132139 )
140+
141+ # MLETORCH-629: Expand does not work on FVP with batch>1
142+ @parameterized .expand (Expand .test_parameters [- 3 :])
143+ @pytest .mark .corstone_fvp
144+ @conftest .expectedFailureOnFVP
145+ def test_expand_u85_BI_xfails (self , test_input , multiples ):
146+ self ._test_expand_ethosu_BI_pipeline (
147+ common .get_u85_compile_spec (), self .Expand (), (test_input , multiples )
148+ )
0 commit comments