1717
1818test_data_suite = [
1919 # (test_name, test_data, dim)
20- ("zeros" , torch .zeros (10 , 10 , 10 , 10 ), 0 ),
21- ("zeros_neg_dim" , torch .zeros (10 , 10 , 10 , 10 ), - 4 ),
20+ ("zeros" , torch .zeros (10 , 8 , 5 , 2 ), 0 ),
21+ ("zeros_neg_dim" , torch .zeros (10 , 7 , 8 , 9 ), - 4 ),
2222 ("ones" , torch .ones (10 , 10 ), 1 ),
23- ("rand_neg_dim" , torch .rand (10 , 10 , 10 ), - 1 ),
24- ("rand" , torch .rand (10 , 10 , 10 , 10 ), 2 ),
25- ("rand_neg_dim" , torch .rand (10 , 10 , 2 , 3 ), - 2 ),
26- ("randn" , torch .randn (10 , 10 , 5 , 10 ), 3 ),
27- ("randn_neg_dim" , torch .randn (1 , 10 , 10 , 10 ), - 3 ),
23+ ("ones_neg_dim" , torch .ones (10 , 3 , 4 ), - 1 ),
24+ ("rand" , torch .rand (1 , 2 , 5 , 8 ), 2 ),
25+ ("rand_neg_dim" , torch .rand (2 , 10 , 8 , 10 ), - 2 ),
26+ ("randn" , torch .randn (10 , 10 , 10 , 10 ), 3 ),
27+ ("randn_neg_dim" , torch .randn (10 , 5 , 8 , 7 ), - 3 ),
28+ ]
29+ test_data_suite_u55 = [
30+ # (test_name, test_data, dim)
31+ ("ones" , torch .ones (10 , 10 ), 1 ),
32+ ("ones_neg_dim" , torch .ones (10 , 3 , 4 ), - 1 ),
33+ ("randn_neg_dim" , torch .randn (10 , 5 , 8 , 7 ), - 3 ),
34+ ]
35+
36+ test_data_suite_u55_xfails = [
37+ # (test_name, test_data, dim)
38+ ("zeros" , torch .zeros (10 , 8 , 5 , 2 ), 0 ),
39+ ("zeros_neg_dim" , torch .zeros (10 , 7 , 8 , 9 ), - 4 ),
40+ ("rand" , torch .rand (1 , 2 , 5 , 8 ), 2 ),
41+ ("rand_neg_dim" , torch .rand (2 , 10 , 8 , 10 ), - 2 ),
42+ ("randn" , torch .randn (10 , 10 , 10 , 10 ), 3 ),
2843]
2944
3045
@@ -135,7 +150,7 @@ def test_logsoftmax_tosa_BI(
135150 ):
136151 self ._test_logsoftmax_tosa_BI_pipeline (self .LogSoftmax (dim = dim ), (test_data ,))
137152
138- @parameterized .expand (test_data_suite )
153+ @parameterized .expand (test_data_suite_u55 )
139154 def test_logsoftmax_tosa_u55_BI (
140155 self ,
141156 test_name : str ,
@@ -146,13 +161,26 @@ def test_logsoftmax_tosa_u55_BI(
146161 self .LogSoftmax (dim = dim ), (test_data ,)
147162 )
148163
164+ # Expected to fail as this is not supported on u55.
165+ @parameterized .expand (test_data_suite_u55_xfails )
166+ @unittest .expectedFailure
167+ def test_logsoftmax_tosa_u55_BI_xfails (
168+ self ,
169+ test_name : str ,
170+ test_data : torch .Tensor ,
171+ dim : int ,
172+ ):
173+ self ._test_logsoftmax_tosa_u55_BI_pipeline (
174+ self .LogSoftmax (dim = dim ), (test_data ,)
175+ )
176+
149177 @parameterized .expand (test_data_suite )
150178 def test_logsoftmax_tosa_u85_BI (
151179 self ,
152180 test_name : str ,
153181 test_data : torch .Tensor ,
154182 dim : int ,
155183 ):
156- self ._test_logsoftmax_tosa_u55_BI_pipeline (
184+ self ._test_logsoftmax_tosa_u85_BI_pipeline (
157185 self .LogSoftmax (dim = dim ), (test_data ,)
158186 )
0 commit comments