@@ -76,14 +76,13 @@ def test_conv2d_int8_in_f32_out(self):
                           bias=bias)
             x = torch.rand(1, in_channels * g, spatial, spatial)
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
             #TODO: enable torch.per_tensor_symmetric case.
             for qscheme in [torch.per_tensor_affine]:
                 graph = self.checkQuantizeTrace(m, [x], x_var=[torch.rand(5, in_channels * g, spatial, spatial, requires_grad=False)], atol=2e-1, config_name="conv2d", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+                self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -93,13 +92,12 @@ def test_linear_int8_in_f32_out(self):
             m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
 
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"],
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -121,16 +119,14 @@ def forward(self, x, y):
             m = M(bias)
 
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
             ]
 
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="linear_int8", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-                self.assertFused(graph, ['aten::linear',
-                                         'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                self.assertFused(graph, ['aten::linear', 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -158,14 +154,12 @@ def test_max_pool2d(self):
             x = torch.rand(1, 3, spatial, spatial)
 
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
-                ["aten::dequantize"]
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="max_pool2d", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-                self.assertFused(graph, ['aten::max_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+                self.assertFused(graph, ['aten::max_pool2d'])
                 self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -212,14 +206,13 @@ def forward(self, x):
             x = torch.rand(1, 32, 28, 28)
 
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::' + eltwise, "aten::quantize_per_tensor"],  # inplace op will become outplace op on the JIT graph
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_channel', 'aten::dequantize'])
                 self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -241,14 +234,13 @@ def forward(self, x):
         # x = torch.rand(1, 32, 28, 28)
 
         patterns = [
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
         # TODO: add torch.per_tensor_symmetric case.
         for qscheme in [torch.per_tensor_affine]:
             graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
-            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -268,15 +260,12 @@ def forward(self, x):
         m = M().eval()
         x = torch.rand(1, 32, 28, 28)
         patterns = [
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::quantize_per_tensor"],
-            ["aten::dequantize"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_channel'])
             self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -305,13 +294,11 @@ def forward(self, x):
             m = M(eltwise_fn, has_bias)
             x = torch.rand(32, 28, requires_grad=False)
             patterns = [
-                ["aten::quantize_per_tensor"],
                 ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::" + eltwise, "aten::quantize_per_tensor"],
-                ["aten::dequantize"]
             ]
             for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                 graph = self.checkQuantizeTrace(m, [x], x_var=[torch.rand(2, 28, requires_grad=False)], atol=1e-1, config_name="linear_eltwise", qscheme=qscheme)
-                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                 self.assertFused(graph, ['aten::' + eltwise])
                 self.checkPatterns(graph, patterns)
 
@@ -343,15 +330,14 @@ def forward(self, x, y):
         x = torch.rand(1, 32, 16, 16, requires_grad=False)
         y = torch.rand(1, 32, 16, 16, requires_grad=False)
         patterns = [
-            ["aten::quantize_per_tensor"],
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", "aten::relu", "aten::add", "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::add', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)
 
     @llga_test_env
@@ -373,29 +359,15 @@ def forward(self, x, y):
         y = torch.randn(2, 20)
         m = M()
         patterns = [
-            ["aten::quantize_per_tensor"],
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::linear", "aten::add", "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, config_name="linear_dropout_sum", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-            self.assertFused(graph, ['aten::linear', 'aten::add',
-                                     'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+            self.assertFused(graph, ['aten::linear', 'aten::add', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)
 
-        # TODO: check patterns when oneDNN support sum post_ops with zps
-        # patterns = [
-        #     ["aten::quantize_per_tensor"],
-        #     ["aten::quantize_per_channel"],
-        #     ["aten::dequantize", "aten::linear", "aten::add", "aten::quantize_per_tensor"],
-        #     ["aten::quantize_per_channel"],
-        #     ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
-        #     ["aten::dequantize"]
-        # ]
-        # self.checkPatterns(graph, patterns)
-
     @llga_test_env
     def test_defer_size(self):
         class M(nn.Module):
@@ -415,14 +387,13 @@ def forward(self, x):
         m = M()
         x = torch.rand(1, 32, 28, 28)
         patterns = [
-            ["aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution", 'aten::relu', "aten::quantize_per_tensor"],
             ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
             graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="defer_size", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+            self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)
 
 class TestShapeFallback(JitLlgaTestCase):
@@ -486,9 +457,7 @@ def _test_vision(self, model_name):
 
         # TODO: aten::adaptive_avg_pool2d also need to be fused once backend supported it
         self.assertFused(graph, ['aten::_convolution', 'aten::relu',
-                                 'aten::max_pool2d', 'aten::linear'
-                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-                                 'aten::dequantize'])
+                                 'aten::max_pool2d', 'aten::linear', 'aten::quantize_per_channel'])
 
 
 for model_name, enabled in [