2525from parameterized import parameterized
2626
2727
28- class TestSimpleView (unittest .TestCase ):
28+ class TestView (unittest .TestCase ):
2929 """Tests the view operation."""
3030
3131 class View (torch .nn .Module ):
3232
33- sizes = [10 , 15 , 50 , 100 ]
34- test_parameters = [(torch .ones (n ),) for n in sizes ]
35-
36- def forward (self , x : torch .Tensor ):
37- return x .view (- 1 , 5 )
33+ needs_transpose_tests = [
34+ (torch .rand (100 ), (1 , - 1 , 5 , 2 )),
35+ (torch .rand (10 , 2 , 1 , 5 ), (1 , - 1 , 5 , 2 )),
36+ (torch .rand (1 , 2 , 1 , 9 ), (3 , 1 , 3 , 2 )),
37+ (torch .rand (2 , 1 , 1 , 9 ), (3 , 2 , 3 , 1 )),
38+ (torch .rand (2 , 50 , 2 , 1 ), (1 , 200 )),
39+ (torch .rand (2 , 5 , 2 , 3 ), (1 , 15 , 4 )),
40+ ]
41+
42+ no_transpose_tests = [
43+ (torch .rand (2 , 1 , 1 , 9 ), (3 , 1 , 3 , 2 )),
44+ (torch .rand (5 , 10 , 1 , 1 ), (25 , 2 , 1 , 1 )),
45+ (torch .rand (10 , 2 ), (1 , 1 , 5 , 4 )),
46+ (torch .rand (10 , 10 ), (5 , 1 , 5 , 4 )),
47+ (torch .rand (1 , 1 , 1 , 10 ), (1 , 1 , 10 , 1 )),
48+ (torch .rand (1 , 1 , 5 , 10 ), (1 , 1 , 50 , 1 )),
49+ (torch .rand (5 , 10 , 1 , 1 ), (1 , 25 , 2 )),
50+ (torch .rand (2 , 50 , 1 , 1 ), (1 , 100 )),
51+ ]
52+
53+ def forward (self , x : torch .Tensor , new_shape ):
54+ return x .view (new_shape )
3855
3956 def _test_view_tosa_MI_pipeline (
4057 self , module : torch .nn .Module , test_data : torch .Tensor
@@ -82,11 +99,7 @@ def _test_view_ethos_BI_pipeline(
8299 ):
83100 quantizer = ArmQuantizer ().set_io (get_symmetric_quantization_config ())
84101 (
85- ArmTester (
86- module ,
87- example_inputs = test_data ,
88- compile_spec = common .get_u55_compile_spec (),
89- )
102+ ArmTester (module , example_inputs = test_data , compile_spec = compile_spec )
90103 .quantize (Quantize (quantizer , get_symmetric_quantization_config ()))
91104 .export ()
92105 .check_count ({"torch.ops.aten.view.default" : 1 })
@@ -110,18 +123,23 @@ def _test_view_u85_BI_pipeline(
110123 common .get_u85_compile_spec (), module , test_data
111124 )
112125
113- @parameterized .expand (View .test_parameters )
114- def test_view_tosa_MI (self , test_tensor : torch .Tensor ):
115- self ._test_view_tosa_MI_pipeline (self .View (), (test_tensor ,))
126+ @parameterized .expand (View .needs_transpose_tests + View .no_transpose_tests )
127+ def test_view_tosa_MI (self , test_tensor : torch .Tensor , new_shape ):
128+ self ._test_view_tosa_MI_pipeline (self .View (), (test_tensor , new_shape ))
129+
130+ @parameterized .expand (View .needs_transpose_tests + View .no_transpose_tests )
131+ def test_view_tosa_BI (self , test_tensor : torch .Tensor , new_shape ):
132+ self ._test_view_tosa_BI_pipeline (self .View (), (test_tensor , new_shape ))
116133
117- @parameterized .expand (View .test_parameters )
118- def test_view_tosa_BI (self , test_tensor : torch .Tensor ):
119- self ._test_view_tosa_BI_pipeline (self .View (), (test_tensor ,))
134+ @parameterized .expand (View .no_transpose_tests )
135+ def test_view_u55_BI (self , test_tensor : torch .Tensor , new_shape ):
136+ self ._test_view_u55_BI_pipeline (self .View (), (test_tensor , new_shape ))
120137
121- @parameterized .expand (View .test_parameters )
122- def test_view_u55_BI (self , test_tensor : torch .Tensor ):
123- self ._test_view_u55_BI_pipeline (self .View (), (test_tensor ,))
138+ @parameterized .expand (View .needs_transpose_tests )
139+ @unittest .expectedFailure
140+ def test_view_transpose_u55_BI (self , test_tensor : torch .Tensor , new_shape ):
141+ self ._test_view_u55_BI_pipeline (self .View (), (test_tensor , new_shape ))
124142
125- @parameterized .expand (View .test_parameters )
126- def test_view_u85_BI (self , test_tensor : torch .Tensor ):
127- self ._test_view_u85_BI_pipeline (self .View (), (test_tensor ,))
143+ @parameterized .expand (View .needs_transpose_tests + View . no_transpose_tests )
144+ def test_view_u85_BI (self , test_tensor : torch .Tensor , new_shape ):
145+ self ._test_view_u85_BI_pipeline (self .View (), (test_tensor , new_shape ))
0 commit comments