@@ -14,9 +14,13 @@
 
 class TestCat(unittest.TestCase):
     class Cat(torch.nn.Module):
+        def __init__(self, dim=0):
+            super().__init__()
+            self.dim = dim
+
         def forward(self, *args):
             xs = [*args]
-            x = torch.cat(xs)
+            x = torch.cat(xs, dim=self.dim)
             return x + x  # Quantize by propagation.
 
     def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
@@ -27,7 +31,6 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
             tester.quantize()
 
         tester.export().check_count({"torch.ops.aten.cat": 1})
-        tester.dump_artifact()
 
         if quant:
             # Expect multiple quantize ops - one per input, cat, and add.
@@ -93,6 +96,29 @@ def test_fp16_cat4(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp16_cat5(self):
+        """
+        Concatenate five fp16 inputs. Note: fp16 add is done in fp32 at the moment.
+        """
+        inputs = (
+            torch.randn(1, 2, 3).to(torch.float16),
+            torch.randn(3, 2, 3).to(torch.float16),
+            torch.randn(2, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+        )
+        self._test_cat(self.Cat(), inputs)
+
+    def test_fp16_cat_gt_5(self):
+        """
+        Concatenate more than five fp16 inputs. Note: fp16 add is done in fp32 at the moment.
+        """
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3).to(torch.float16))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_fp32_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs)
@@ -120,6 +146,13 @@ def test_fp32_cat5(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp32_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_qs8_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
@@ -137,46 +170,22 @@ def test_qs8_cat4(self):
         )
         self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
 
-    def test_fp32_cat_unsupported(self):
-        """
-        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
-        """
+    def test_qs8_cat5(self):
         inputs = (
             torch.randn(1, 2, 3),
             torch.randn(3, 2, 3),
             torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(2, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge_transform_and_lower()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
-        )
-
-    def test_fp32_cat_unsupported_legacy_mode(self):
-        """
-        XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
-        """
-        inputs = (
-            torch.randn(1, 2, 3),
-            torch.randn(3, 2, 3),
-            torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(6, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge()
-            .partition()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
         )
+        self._test_cat(self.Cat(), inputs, cat_num=5, quant=True)
+
+    def test_qs8_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
 
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):
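
For context on what the new `*_cat_gt_5` tests exercise, here is a minimal standalone sketch in plain PyTorch (no ExecuTorch `Tester` harness): the `Cat` module from this diff applied to more than five inputs. The shape assertion is an illustrative addition, not part of the test suite; it checks only eager-mode output shapes, not XNNPACK delegation.

import torch


class Cat(torch.nn.Module):
    def __init__(self, dim=0):
        super().__init__()
        self.dim = dim

    def forward(self, *args):
        xs = [*args]
        x = torch.cat(xs, dim=self.dim)
        return x + x  # Quantize by propagation.


module = Cat()
for num_inputs in range(6, 10):
    # Mirrors test_fp32_cat_gt_5: num_inputs tensors of shape (1, 2, 3).
    inputs = tuple(torch.randn(1, 2, 3) for _ in range(num_inputs))
    out = module(*inputs)
    # cat along dim 0 stacks the leading dims into (num_inputs, 2, 3);
    # the trailing add does not change the shape.
    assert out.shape == (num_inputs, 2, 3)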