Skip to content

Commit 76209db

Browse files
authored
Update quantized matmul tests to DQ/Q format supported by fx_importer (llvm#3815)
Continuation of llvm#3809 for the matmul tests.
1 parent 1259e8a commit 76209db

File tree

2 files changed: 75 additions and 64 deletions.

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -394,15 +394,6 @@
394394
"AtenIntBoolOpModule_basic",
395395
"AtenIntMM_basic",
396396
"AtenItemFpOpModule_basic",
397-
"AtenMatmulQMixedSigni8Transpose_basic",
398-
"AtenMatmulQMixedSigni8_basic",
399-
"AtenMatmulQint8MV_basic",
400-
"AtenMatmulQint8_basic",
401-
"AtenMatmulQint8VM_basic",
402-
"AtenMatmulQint8VV_basic",
403-
"AtenMmQMixedSigni8_basic",
404-
"AtenMmQint8_basic",
405-
"AtenMmQuint8_basic",
406397
"QuantizedReluInt32_basic",
407398
"QuantizedReluInt8_basic",
408399
"QuantizedReluUint8_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 75 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
337337

338338

339339
# ==============================================================================
340+
# For DQ-Q fake quantization ops
341+
import torch.ao.quantization.fx._decomposed
340342

341343

342344
class AtenMmQint8(torch.nn.Module):
@@ -352,12 +354,14 @@ def __init__(self):
352354
]
353355
)
354356
def forward(self, x, y):
355-
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
356-
qx = torch.dequantize(qx)
357-
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
358-
qy = torch.dequantize(qy)
359-
qz = torch.mm(qx, qy)
360-
return qz
357+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
358+
x, 0.0215, -25, -128, 127, torch.int8
359+
)
360+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
361+
y, 0.0176, 18, -128, 127, torch.int8
362+
)
363+
z = torch.mm(x, y)
364+
return z
361365

362366

363367
@register_test_case(module_factory=lambda: AtenMmQint8())
@@ -384,12 +388,14 @@ def __init__(self):
384388
]
385389
)
386390
def forward(self, x, y):
387-
qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
388-
qx = torch.dequantize(qx)
389-
qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160)
390-
qy = torch.dequantize(qy)
391-
qz = torch.mm(qx, qy)
392-
return qz
391+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
392+
x, 0.199, 65, 0, 255, torch.uint8
393+
)
394+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
395+
y, 0.0215, 160, 0, 255, torch.uint8
396+
)
397+
z = torch.mm(x, y)
398+
return z
393399

394400

395401
@register_test_case(module_factory=lambda: AtenMmQuint8())
@@ -416,12 +422,14 @@ def __init__(self):
416422
]
417423
)
418424
def forward(self, x, y):
419-
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
420-
qx = torch.dequantize(qx)
421-
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
422-
qy = torch.dequantize(qy)
423-
qz = torch.mm(qx, qy)
424-
return qz
425+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
426+
x, 0.03, -66, -128, 127, torch.int8
427+
)
428+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
429+
y, 0.025, 160, 0, 255, torch.uint8
430+
)
431+
z = torch.mm(x, y)
432+
return z
425433

426434

427435
@register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
@@ -475,12 +483,14 @@ def __init__(self):
475483
]
476484
)
477485
def forward(self, x, y):
478-
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
479-
qx = torch.dequantize(qx)
480-
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
481-
qy = torch.dequantize(qy)
482-
qz = torch.matmul(qx, qy)
483-
return qz
486+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
487+
x, 0.0215, -25, -128, 127, torch.int8
488+
)
489+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
490+
y, 0.0176, 18, -128, 127, torch.int8
491+
)
492+
z = torch.matmul(x, y)
493+
return z
484494

485495

486496
@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
@@ -505,12 +515,14 @@ def __init__(self):
505515
]
506516
)
507517
def forward(self, x, y):
508-
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
509-
qx = torch.dequantize(qx)
510-
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
511-
qy = torch.dequantize(qy)
512-
qz = torch.matmul(qx, qy)
513-
return qz
518+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
519+
x, 0.0215, -25, -128, 127, torch.int8
520+
)
521+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
522+
y, 0.0176, 18, -128, 127, torch.int8
523+
)
524+
z = torch.matmul(x, y)
525+
return z
514526

515527

516528
@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
@@ -535,12 +547,14 @@ def __init__(self):
535547
]
536548
)
537549
def forward(self, x, y):
538-
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
539-
qx = torch.dequantize(qx)
540-
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
541-
qy = torch.dequantize(qy)
542-
qz = torch.matmul(qx, qy)
543-
return qz
550+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
551+
x, 0.0215, -25, -128, 127, torch.int8
552+
)
553+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
554+
y, 0.0176, 18, -128, 127, torch.int8
555+
)
556+
z = torch.matmul(x, y)
557+
return z
544558

545559

546560
@register_test_case(module_factory=lambda: AtenMatmulQint8MV())
@@ -565,12 +579,14 @@ def __init__(self):
565579
]
566580
)
567581
def forward(self, x, y):
568-
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
569-
qx = torch.dequantize(qx)
570-
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
571-
qy = torch.dequantize(qy)
572-
qz = torch.matmul(qx, qy)
573-
return qz
582+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
583+
x, 0.0215, -25, -128, 127, torch.int8
584+
)
585+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
586+
y, 0.0176, 18, -128, 127, torch.int8
587+
)
588+
z = torch.matmul(x, y)
589+
return z
574590

575591

576592
@register_test_case(module_factory=lambda: AtenMatmulQint8())
@@ -597,12 +613,14 @@ def __init__(self):
597613
]
598614
)
599615
def forward(self, x, y):
600-
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
601-
qx = torch.dequantize(qx)
602-
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
603-
qy = torch.dequantize(qy)
604-
qz = torch.matmul(qx, qy)
605-
return qz
616+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
617+
x, 0.03, -66, -128, 127, torch.int8
618+
)
619+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
620+
y, 0.025, 160, 0, 255, torch.uint8
621+
)
622+
z = torch.matmul(x, y)
623+
return z
606624

607625

608626
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
@@ -629,13 +647,15 @@ def __init__(self):
629647
]
630648
)
631649
def forward(self, x, y):
632-
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
633-
qx = torch.dequantize(qx)
634-
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
635-
qy = torch.dequantize(qy)
636-
qy = torch.transpose(qy, 1, 2)
637-
qz = torch.matmul(qx, qy)
638-
return qz
650+
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
651+
x, 0.03, -66, -128, 127, torch.int8
652+
)
653+
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
654+
y, 0.025, 160, 0, 255, torch.uint8
655+
)
656+
y = torch.transpose(y, 1, 2)
657+
z = torch.matmul(x, y)
658+
return z
639659

640660

641661
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())

0 commit comments

Comments (0)