@@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
337337
338338
339339# ==============================================================================
340+ # For DQ-Q fake quantization ops
341+ import torch .ao .quantization .fx ._decomposed
340342
341343
342344class AtenMmQint8 (torch .nn .Module ):
@@ -352,12 +354,14 @@ def __init__(self):
352354 ]
353355 )
def forward(self, x, y):
    """Fake-quantized int8 matmul: dequantize both operands, then torch.mm.

    Args:
        x: int8 tensor of quantized values (scale 0.0215, zero point -25).
        y: int8 tensor of quantized values (scale 0.0176, zero point 18).

    Returns:
        Float tensor: torch.mm of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    # The decomposed op requires `import torch.ao.quantization.fx._decomposed`
    # at module level to register the quantized_decomposed namespace.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.0176, 18, -128, 127, torch.int8
    )
    return torch.mm(x, y)
361365
362366
363367@register_test_case (module_factory = lambda : AtenMmQint8 ())
@@ -384,12 +388,14 @@ def __init__(self):
384388 ]
385389 )
def forward(self, x, y):
    """Fake-quantized uint8 matmul: dequantize both operands, then torch.mm.

    Args:
        x: uint8 tensor of quantized values (scale 0.199, zero point 65).
        y: uint8 tensor of quantized values (scale 0.0215, zero point 160).

    Returns:
        Float tensor: torch.mm of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.199, 65, 0, 255, torch.uint8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.0215, 160, 0, 255, torch.uint8
    )
    return torch.mm(x, y)
393399
394400
395401@register_test_case (module_factory = lambda : AtenMmQuint8 ())
@@ -416,12 +422,14 @@ def __init__(self):
416422 ]
417423 )
def forward(self, x, y):
    """Mixed-signedness fake-quantized matmul: int8 x, uint8 y, then torch.mm.

    Args:
        x: int8 tensor of quantized values (scale 0.03, zero point -66).
        y: uint8 tensor of quantized values (scale 0.025, zero point 160).

    Returns:
        Float tensor: torch.mm of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.03, -66, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.025, 160, 0, 255, torch.uint8
    )
    return torch.mm(x, y)
425433
426434
427435@register_test_case (module_factory = lambda : AtenMmQMixedSigni8 ())
@@ -475,12 +483,14 @@ def __init__(self):
475483 ]
476484 )
def forward(self, x, y):
    """Fake-quantized int8 vector-matrix product via torch.matmul.

    Args:
        x: int8 tensor of quantized values (scale 0.0215, zero point -25).
        y: int8 tensor of quantized values (scale 0.0176, zero point 18).

    Returns:
        Float tensor: torch.matmul of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.0176, 18, -128, 127, torch.int8
    )
    return torch.matmul(x, y)
484494
485495
486496@register_test_case (module_factory = lambda : AtenMatmulQint8VM ())
@@ -505,12 +515,14 @@ def __init__(self):
505515 ]
506516 )
def forward(self, x, y):
    """Fake-quantized int8 vector-vector (dot) product via torch.matmul.

    Args:
        x: int8 tensor of quantized values (scale 0.0215, zero point -25).
        y: int8 tensor of quantized values (scale 0.0176, zero point 18).

    Returns:
        Float tensor: torch.matmul of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.0176, 18, -128, 127, torch.int8
    )
    return torch.matmul(x, y)
514526
515527
516528@register_test_case (module_factory = lambda : AtenMatmulQint8VV ())
@@ -535,12 +547,14 @@ def __init__(self):
535547 ]
536548 )
def forward(self, x, y):
    """Fake-quantized int8 matrix-vector product via torch.matmul.

    Args:
        x: int8 tensor of quantized values (scale 0.0215, zero point -25).
        y: int8 tensor of quantized values (scale 0.0176, zero point 18).

    Returns:
        Float tensor: torch.matmul of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.0176, 18, -128, 127, torch.int8
    )
    return torch.matmul(x, y)
544558
545559
546560@register_test_case (module_factory = lambda : AtenMatmulQint8MV ())
@@ -565,12 +579,14 @@ def __init__(self):
565579 ]
566580 )
def forward(self, x, y):
    """Fake-quantized int8 matmul (general/batched) via torch.matmul.

    Args:
        x: int8 tensor of quantized values (scale 0.0215, zero point -25).
        y: int8 tensor of quantized values (scale 0.0176, zero point 18).

    Returns:
        Float tensor: torch.matmul of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.0176, 18, -128, 127, torch.int8
    )
    return torch.matmul(x, y)
574590
575591
576592@register_test_case (module_factory = lambda : AtenMatmulQint8 ())
@@ -597,12 +613,14 @@ def __init__(self):
597613 ]
598614 )
def forward(self, x, y):
    """Mixed-signedness fake-quantized matmul: int8 x, uint8 y.

    Args:
        x: int8 tensor of quantized values (scale 0.03, zero point -66).
        y: uint8 tensor of quantized values (scale 0.025, zero point 160).

    Returns:
        Float tensor: torch.matmul of the dequantized operands.
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.03, -66, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.025, 160, 0, 255, torch.uint8
    )
    return torch.matmul(x, y)
606624
607625
608626@register_test_case (module_factory = lambda : AtenMatmulQMixedSigni8 ())
@@ -629,13 +647,15 @@ def __init__(self):
629647 ]
630648 )
def forward(self, x, y):
    """Mixed-signedness fake-quantized matmul with a transposed RHS.

    Dequantizes int8 x and uint8 y, swaps dims 1 and 2 of y (so y must be
    at least 3-D), then multiplies with torch.matmul.

    Args:
        x: int8 tensor of quantized values (scale 0.03, zero point -66).
        y: uint8 tensor of quantized values (scale 0.025, zero point 160).

    Returns:
        Float tensor: torch.matmul(x_dq, y_dq.transpose(1, 2)).
    """
    # `torch.torch.ops` is redundant (torch re-exports itself); use torch.ops.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.03, -66, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.025, 160, 0, 255, torch.uint8
    )
    # Transpose after dequantization, exactly as the original test does.
    y = torch.transpose(y, 1, 2)
    return torch.matmul(x, y)
639659
640660
641661@register_test_case (module_factory = lambda : AtenMatmulQMixedSigni8Transpose ())
0 commit comments