1818from executorch .backends .arm .test import common
1919from executorch .backends .arm .test .tester .arm_tester import ArmTester
2020from executorch .backends .xnnpack .test .tester .tester import Quantize
21+ from executorch .exir .backend .compile_spec_schema import CompileSpec
2122from parameterized import parameterized
2223from torchvision .ops import Permute
2324
2425test_data_suite = [
2526 # (test_name,test_data,dims)
26- ("zeros " , torch .zeros (10 , 10 , 10 , 10 ), [1 , 0 , 3 , 2 ]),
27- ("ones " , torch .ones (10 , 10 , 10 , 10 ), [3 , 1 , 0 , 2 ]),
28- ("rand " , torch .rand (10 , 10 , 10 , 10 ) - 0.5 , [0 , 2 , 3 , 1 ]),
29- ("randn_pos " , torch .randn ( 10 , 10 , 10 ) + 10 , [2 , 0 , 1 ]),
30- ("randn_neg " , torch .randn ( 10 , 10 , 10 ) - 10 , [1 , 2 , 0 ]),
31- ("ramp " , torch .arange ( - 16 , 16 , 0.2 ), [0 ]),
27+ ("rank_2 " , torch .rand (10 , 10 ), [1 , 0 ]),
28+ ("rank_3 " , torch .rand (10 , 10 , 10 ), [2 , 0 , 1 ]),
29+ ("rank_3 " , torch .rand (10 , 10 , 10 ) , [1 , 2 , 0 ]),
30+ ("rank_4 " , torch .rand ( 1 , 5 , 1 , 10 ) , [0 , 2 , 3 , 1 ]),
31+ ("rank_4 " , torch .rand ( 1 , 2 , 5 , 10 ) , [1 , 0 , 2 , 3 ]),
32+ ("rank_4 " , torch .rand ( 1 , 10 , 10 , 5 ), [2 , 0 , 1 , 3 ]),
3233]
3334
3435
@@ -46,13 +47,18 @@ def forward(self, x):
4647 return self .permute (x )
4748
4849 def _test_permute_tosa_MI_pipeline (
49- self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
50+ self ,
51+ module : torch .nn .Module ,
52+ test_data : Tuple [torch .tensor ],
53+ permute_memory_to_nhwc : bool ,
5054 ):
5155 (
5256 ArmTester (
5357 module ,
5458 example_inputs = test_data ,
55- compile_spec = common .get_tosa_compile_spec (),
59+ compile_spec = common .get_tosa_compile_spec (
60+ permute_memory_to_nhwc = permute_memory_to_nhwc
61+ ),
5662 )
5763 .export ()
5864 .check (["torch.ops.aten.permute.default" ])
@@ -87,15 +93,18 @@ def _test_permute_tosa_BI_pipeline(
8793 .run_method_and_compare_outputs (inputs = test_data )
8894 )
8995
90- def _test_permute_tosa_u55_BI_pipeline (
91- self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
96+ def _test_permute_ethos_BI_pipeline (
97+ self ,
98+ module : torch .nn .Module ,
99+ compile_spec : CompileSpec ,
100+ test_data : Tuple [torch .Tensor ],
92101 ):
93102 quantizer = ArmQuantizer ().set_io (get_symmetric_quantization_config ())
94103 (
95104 ArmTester (
96105 module ,
97106 example_inputs = test_data ,
98- compile_spec = common . get_u55_compile_spec () ,
107+ compile_spec = compile_spec ,
99108 )
100109 .quantize (Quantize (quantizer , get_symmetric_quantization_config ()))
101110 .export ()
@@ -106,24 +115,38 @@ def _test_permute_tosa_u55_BI_pipeline(
106115 .check_not (["executorch_exir_dialects_edge__ops_aten_permute_default" ])
107116 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
108117 .to_executorch ()
118+ .serialize ()
109119 )
110120
111121 @parameterized .expand (test_data_suite )
112122 def test_permute_tosa_MI (
113123 self , test_name : str , test_data : torch .Tensor , dims : list [int ]
114124 ):
115- self ._test_permute_tosa_MI_pipeline (self .Permute (dims = dims ), (test_data ,))
125+ self ._test_permute_tosa_MI_pipeline (self .Permute (dims = dims ), (test_data ,), True )
126+ self ._test_permute_tosa_MI_pipeline (
127+ self .Permute (dims = dims ), (test_data ,), False
128+ )
116129
117130 @parameterized .expand (test_data_suite )
118131 def test_permute_tosa_BI (
119132 self , test_name : str , test_data : torch .Tensor , dims : list [int ]
120133 ):
121134 self ._test_permute_tosa_BI_pipeline (self .Permute (dims = dims ), (test_data ,))
122135
123- # Expected to fail as Permute is not supported by the NPU
124- @parameterized .expand (test_data_suite )
136+ # Expected to fail as TOSA.Transpose is not supported by Ethos-U55.
137+ @parameterized .expand (test_data_suite [ 0 : 1 ] )
125138 @unittest .expectedFailure
126- def test_permute_tosa_u55_BI (
139+ def test_permute_u55_BI (
127140 self , test_name : str , test_data : torch .Tensor , dims : list [int ]
128141 ):
129- self ._test_permute_tosa_u55_BI_pipeline (self .Permute (dims = dims ), (test_data ,))
142+ self ._test_permute_ethos_BI_pipeline (
143+ self .Permute (dims = dims ), common .get_u55_compile_spec (), (test_data ,)
144+ )
145+
146+ @parameterized .expand (test_data_suite )
147+ def test_permute_u85_BI (
148+ self , test_name : str , test_data : torch .Tensor , dims : list [int ]
149+ ):
150+ self ._test_permute_ethos_BI_pipeline (
151+ self .Permute (dims = dims ), common .get_u85_compile_spec (), (test_data ,)
152+ )
0 commit comments