44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ # pyre-unsafe
8+
79import unittest
810
911import torch
1012from executorch .backends .xnnpack .test .tester import Tester
1113
1214
1315class TestCat (unittest .TestCase ):
14- class Cat2 (torch .nn .Module ):
15- def forward (self , arg1 , arg2 ):
16- xs = [arg1 , arg2 ]
17- x = torch .cat (xs )
18- return x + x # Quantize by propagation.
19-
20- class Cat3 (torch .nn .Module ):
21- def forward (self , arg1 , arg2 , arg3 ):
22- xs = [arg1 , arg2 , arg3 ]
23- x = torch .cat (xs )
24- return x + x # Quantize by propagation.
25-
26- class Cat4 (torch .nn .Module ):
27- def forward (self , arg1 , arg2 , arg3 , arg4 ):
28- xs = [arg1 , arg2 , arg3 , arg4 ]
29- x = torch .cat (xs )
30- return x + x # Quantize by propagation.
31-
32- class Cat5 (torch .nn .Module ):
33- def forward (self , arg1 , arg2 , arg3 , arg4 , arg5 ):
34- xs = [arg1 , arg2 , arg3 , arg4 , arg5 ]
16+ class Cat (torch .nn .Module ):
17+ def forward (self , * args ):
18+ xs = [* args ]
3519 x = torch .cat (xs )
3620 return x + x # Quantize by propagation.
3721
@@ -84,7 +68,7 @@ def test_fp16_cat2(self):
8468 torch .randn (1 , 2 , 3 ).to (torch .float16 ),
8569 torch .randn (3 , 2 , 3 ).to (torch .float16 ),
8670 )
87- self ._test_cat (self .Cat2 (), inputs )
71+ self ._test_cat (self .Cat (), inputs )
8872
8973 def test_fp16_cat3 (self ):
9074 """
@@ -95,7 +79,7 @@ def test_fp16_cat3(self):
9579 torch .randn (3 , 2 , 3 ).to (torch .float16 ),
9680 torch .randn (2 , 2 , 3 ).to (torch .float16 ),
9781 )
98- self ._test_cat (self .Cat3 (), inputs )
82+ self ._test_cat (self .Cat (), inputs )
9983
10084 def test_fp16_cat4 (self ):
10185 """
@@ -107,15 +91,15 @@ def test_fp16_cat4(self):
10791 torch .randn (2 , 2 , 3 ).to (torch .float16 ),
10892 torch .randn (5 , 2 , 3 ).to (torch .float16 ),
10993 )
110- self ._test_cat (self .Cat4 (), inputs )
94+ self ._test_cat (self .Cat (), inputs )
11195
11296 def test_fp32_cat2 (self ):
11397 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ))
114- self ._test_cat (self .Cat2 (), inputs )
98+ self ._test_cat (self .Cat (), inputs )
11599
116100 def test_fp32_cat3 (self ):
117101 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ), torch .randn (2 , 2 , 3 ))
118- self ._test_cat (self .Cat3 (), inputs )
102+ self ._test_cat (self .Cat (), inputs )
119103
120104 def test_fp32_cat4 (self ):
121105 inputs = (
@@ -124,15 +108,25 @@ def test_fp32_cat4(self):
124108 torch .randn (2 , 2 , 3 ),
125109 torch .randn (5 , 2 , 3 ),
126110 )
127- self ._test_cat (self .Cat4 (), inputs )
111+ self ._test_cat (self .Cat (), inputs )
112+
113+ def test_fp32_cat5 (self ):
114+ inputs = (
115+ torch .randn (1 , 2 , 3 ),
116+ torch .randn (3 , 2 , 3 ),
117+ torch .randn (2 , 2 , 3 ),
118+ torch .randn (5 , 2 , 3 ),
119+ torch .randn (1 , 2 , 3 ),
120+ )
121+ self ._test_cat (self .Cat (), inputs )
128122
129123 def test_qs8_cat2 (self ):
130124 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ))
131- self ._test_cat (self .Cat2 (), inputs , cat_num = 2 , quant = True )
125+ self ._test_cat (self .Cat (), inputs , cat_num = 2 , quant = True )
132126
133127 def test_qs8_cat3 (self ):
134128 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ), torch .randn (2 , 2 , 3 ))
135- self ._test_cat (self .Cat3 (), inputs , cat_num = 3 , quant = True )
129+ self ._test_cat (self .Cat (), inputs , cat_num = 3 , quant = True )
136130
137131 def test_qs8_cat4 (self ):
138132 inputs = (
@@ -141,7 +135,7 @@ def test_qs8_cat4(self):
141135 torch .randn (2 , 2 , 3 ),
142136 torch .randn (5 , 2 , 3 ),
143137 )
144- self ._test_cat (self .Cat4 (), inputs , cat_num = 4 , quant = True )
138+ self ._test_cat (self .Cat (), inputs , cat_num = 4 , quant = True )
145139
146140 def test_fp32_cat_unsupported (self ):
147141 """
@@ -153,9 +147,10 @@ def test_fp32_cat_unsupported(self):
153147 torch .randn (2 , 2 , 3 ),
154148 torch .randn (5 , 2 , 3 ),
155149 torch .randn (1 , 2 , 3 ),
150+ torch .randn (2 , 2 , 3 ),
156151 )
157152 (
158- Tester (self .Cat5 (), inputs )
153+ Tester (self .Cat (), inputs )
159154 .export ()
160155 .check_count ({"torch.ops.aten.cat" : 1 })
161156 .to_edge_transform_and_lower ()
@@ -164,17 +159,18 @@ def test_fp32_cat_unsupported(self):
164159
165160 def test_fp32_cat_unsupported_legacy_mode (self ):
166161 """
167- XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
162+ XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
168163 """
169164 inputs = (
170165 torch .randn (1 , 2 , 3 ),
171166 torch .randn (3 , 2 , 3 ),
172167 torch .randn (2 , 2 , 3 ),
173168 torch .randn (5 , 2 , 3 ),
174169 torch .randn (1 , 2 , 3 ),
170+ torch .randn (6 , 2 , 3 ),
175171 )
176172 (
177- Tester (self .Cat5 (), inputs )
173+ Tester (self .Cat (), inputs )
178174 .export ()
179175 .check_count ({"torch.ops.aten.cat" : 1 })
180176 .to_edge ()
0 commit comments