1111
1212
1313class TestCat (unittest .TestCase ):
14- class Cat2 (torch .nn .Module ):
15- def forward (self , arg1 , arg2 ):
16- xs = [arg1 , arg2 ]
17- x = torch .cat (xs )
18- return x + x # Quantize by propagation.
19-
20- class Cat3 (torch .nn .Module ):
21- def forward (self , arg1 , arg2 , arg3 ):
22- xs = [arg1 , arg2 , arg3 ]
23- x = torch .cat (xs )
24- return x + x # Quantize by propagation.
25-
26- class Cat4 (torch .nn .Module ):
27- def forward (self , arg1 , arg2 , arg3 , arg4 ):
28- xs = [arg1 , arg2 , arg3 , arg4 ]
29- x = torch .cat (xs )
30- return x + x # Quantize by propagation.
31-
32- class Cat5 (torch .nn .Module ):
33- def forward (self , arg1 , arg2 , arg3 , arg4 , arg5 ):
34- xs = [arg1 , arg2 , arg3 , arg4 , arg5 ]
14+ class Cat (torch .nn .Module ):
15+ def forward (self , * args ):
16+ xs = [* args ]
3517 x = torch .cat (xs )
3618 return x + x # Quantize by propagation.
3719
@@ -84,7 +66,7 @@ def test_fp16_cat2(self):
8466 torch .randn (1 , 2 , 3 ).to (torch .float16 ),
8567 torch .randn (3 , 2 , 3 ).to (torch .float16 ),
8668 )
87- self ._test_cat (self .Cat2 (), inputs )
69+ self ._test_cat (self .Cat (), inputs )
8870
8971 def test_fp16_cat3 (self ):
9072 """
@@ -95,7 +77,7 @@ def test_fp16_cat3(self):
9577 torch .randn (3 , 2 , 3 ).to (torch .float16 ),
9678 torch .randn (2 , 2 , 3 ).to (torch .float16 ),
9779 )
98- self ._test_cat (self .Cat3 (), inputs )
80+ self ._test_cat (self .Cat (), inputs )
9981
10082 def test_fp16_cat4 (self ):
10183 """
@@ -107,15 +89,15 @@ def test_fp16_cat4(self):
10789 torch .randn (2 , 2 , 3 ).to (torch .float16 ),
10890 torch .randn (5 , 2 , 3 ).to (torch .float16 ),
10991 )
110- self ._test_cat (self .Cat4 (), inputs )
92+ self ._test_cat (self .Cat (), inputs )
11193
11294 def test_fp32_cat2 (self ):
11395 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ))
114- self ._test_cat (self .Cat2 (), inputs )
96+ self ._test_cat (self .Cat (), inputs )
11597
11698 def test_fp32_cat3 (self ):
11799 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ), torch .randn (2 , 2 , 3 ))
118- self ._test_cat (self .Cat3 (), inputs )
100+ self ._test_cat (self .Cat (), inputs )
119101
120102 def test_fp32_cat4 (self ):
121103 inputs = (
@@ -124,15 +106,25 @@ def test_fp32_cat4(self):
124106 torch .randn (2 , 2 , 3 ),
125107 torch .randn (5 , 2 , 3 ),
126108 )
127- self ._test_cat (self .Cat4 (), inputs )
109+ self ._test_cat (self .Cat (), inputs )
110+
111+ def test_fp32_cat5 (self ):
112+ inputs = (
113+ torch .randn (1 , 2 , 3 ),
114+ torch .randn (3 , 2 , 3 ),
115+ torch .randn (2 , 2 , 3 ),
116+ torch .randn (5 , 2 , 3 ),
117+ torch .randn (1 , 2 , 3 ),
118+ )
119+ self ._test_cat (self .Cat (), inputs )
128120
129121 def test_qs8_cat2 (self ):
130122 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ))
131- self ._test_cat (self .Cat2 (), inputs , cat_num = 2 , quant = True )
123+ self ._test_cat (self .Cat (), inputs , cat_num = 2 , quant = True )
132124
133125 def test_qs8_cat3 (self ):
134126 inputs = (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ), torch .randn (2 , 2 , 3 ))
135- self ._test_cat (self .Cat3 (), inputs , cat_num = 3 , quant = True )
127+ self ._test_cat (self .Cat (), inputs , cat_num = 3 , quant = True )
136128
137129 def test_qs8_cat4 (self ):
138130 inputs = (
@@ -141,7 +133,7 @@ def test_qs8_cat4(self):
141133 torch .randn (2 , 2 , 3 ),
142134 torch .randn (5 , 2 , 3 ),
143135 )
144- self ._test_cat (self .Cat4 (), inputs , cat_num = 4 , quant = True )
136+ self ._test_cat (self .Cat (), inputs , cat_num = 4 , quant = True )
145137
146138 def test_fp32_cat_unsupported (self ):
147139 """
@@ -153,9 +145,10 @@ def test_fp32_cat_unsupported(self):
153145 torch .randn (2 , 2 , 3 ),
154146 torch .randn (5 , 2 , 3 ),
155147 torch .randn (1 , 2 , 3 ),
148+ torch .randn (2 , 2 , 3 ),
156149 )
157150 (
158- Tester (self .Cat5 (), inputs )
151+ Tester (self .Cat (), inputs )
159152 .export ()
160153 .check_count ({"torch.ops.aten.cat" : 1 })
161154 .to_edge_transform_and_lower ()
@@ -164,17 +157,18 @@ def test_fp32_cat_unsupported(self):
164157
165158 def test_fp32_cat_unsupported_legacy_mode (self ):
166159 """
167- XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
160+ XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
168161 """
169162 inputs = (
170163 torch .randn (1 , 2 , 3 ),
171164 torch .randn (3 , 2 , 3 ),
172165 torch .randn (2 , 2 , 3 ),
173166 torch .randn (5 , 2 , 3 ),
174167 torch .randn (1 , 2 , 3 ),
168+ torch .randn (6 , 2 , 3 ),
175169 )
176170 (
177- Tester (self .Cat5 (), inputs )
171+ Tester (self .Cat (), inputs )
178172 .export ()
179173 .check_count ({"torch.ops.aten.cat" : 1 })
180174 .to_edge ()
0 commit comments