Skip to content

Commit 9d7595f

Browse files
authored
Merge pull request #475 from Xilinx/bump_to_42ba541c
[AutoBump] Merge with fixes of 42ba541 (Oct 22) (91)
2 parents 129404f + 8201fa8 commit 9d7595f

File tree

3 files changed

+56
-48
lines changed

3 files changed

+56
-48
lines changed

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -420,15 +420,7 @@
420420
"ContainsIntList_False",
421421
"ContainsIntList_True",
422422
"Conv1dNoPaddingGroupModule_basic",
423-
"Conv2dQInt8Module_basic",
424-
"Conv2dQInt8Module_depthwise",
425-
"Conv2dQInt8Module_grouped",
426-
"Conv2dQInt8Module_not_depthwise",
427-
"Conv2dQInt8PerChannelModule_basic",
428-
"Conv2dQInt8PerChannelModule_depthwise",
429-
"Conv2dQInt8PerChannelModule_grouped",
430423
"ConvTbcModule_basic",
431-
"ConvTranspose2DQInt8_basic",
432424
"ConvolutionBackwardModule2DPadded_basic",
433425
"ConvolutionBackwardModule2DStrided_basic",
434426
"ConvolutionBackwardModule2D_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,23 +1300,28 @@ def ConvTbcModule_basic(module, tu: TestUtils):
13001300
module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6))
13011301

13021302

1303+
# For DQ-Q fake quantization ops
1304+
import torch.ao.quantization.fx._decomposed
1305+
1306+
13031307
class Conv2dQInt8ModuleBase(torch.nn.Module):
13041308
def __init__(self, groups=1):
13051309
self.groups = groups
13061310
super().__init__()
13071311

1308-
def _forward(self, inputVec, weight, bias):
1309-
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
1310-
inputVec = torch.dequantize(inputVec)
1311-
1312-
weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3)
1313-
weight = torch.dequantize(weight)
1314-
1315-
bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
1316-
bias = torch.dequantize(bias)
1312+
def _forward(self, input, weight, bias):
1313+
input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1314+
input, 0.01, 7, -128, 127, torch.int8
1315+
)
1316+
weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1317+
weight, 0.01, 3, -128, 127, torch.int8
1318+
)
1319+
bias = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1320+
bias, 1, 0, -1000, 1000, torch.int32
1321+
)
13171322

1318-
return torch.ops.aten.conv2d(
1319-
inputVec,
1323+
conv = torch.ops.aten.conv2d(
1324+
input,
13201325
weight,
13211326
bias=bias,
13221327
stride=[1, 1],
@@ -1325,6 +1330,11 @@ def _forward(self, inputVec, weight, bias):
13251330
groups=self.groups,
13261331
)
13271332

1333+
# Use int32 to avoid overflows
1334+
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
1335+
conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
1336+
)
1337+
13281338

13291339
class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
13301340
@export
@@ -1333,7 +1343,7 @@ class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
13331343
None,
13341344
([-1, -1, -1, -1], torch.int8, True),
13351345
([-1, -1, -1, -1], torch.int8, True),
1336-
([-1], torch.float, True),
1346+
([-1], torch.int32, True),
13371347
]
13381348
)
13391349
def forward(self, inputVec, weight, bias):
@@ -1347,7 +1357,7 @@ class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase):
13471357
None,
13481358
([2, 3, 12, 12], torch.int8, True),
13491359
([3, 1, 5, 3], torch.int8, True),
1350-
([3], torch.float, True),
1360+
([3], torch.int32, True),
13511361
]
13521362
)
13531363
def forward(self, inputVec, weight, bias):
@@ -1361,7 +1371,7 @@ class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
13611371
None,
13621372
([2, 3, 12, 12], torch.int8, True),
13631373
([6, 1, 5, 3], torch.int8, True),
1364-
([6], torch.float, True),
1374+
([6], torch.int32, True),
13651375
]
13661376
)
13671377
def forward(self, inputVec, weight, bias):
@@ -1372,23 +1382,23 @@ def forward(self, inputVec, weight, bias):
13721382
def Conv2dQInt8Module_basic(module, tu: TestUtils):
13731383
inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
13741384
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
1375-
bias = torch.rand(3)
1385+
bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
13761386
module.forward(inputVec, weight, bias)
13771387

13781388

13791389
@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn(groups=2))
13801390
def Conv2dQInt8Module_grouped(module, tu: TestUtils):
13811391
inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
13821392
weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
1383-
bias = torch.rand(6)
1393+
bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
13841394
module.forward(inputVec, weight, bias)
13851395

13861396

13871397
@register_test_case(module_factory=lambda: Conv2dQInt8ModuleStatic(groups=3))
13881398
def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
13891399
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
13901400
weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
1391-
bias = torch.rand(3)
1401+
bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
13921402
module.forward(inputVec, weight, bias)
13931403

13941404

@@ -1398,7 +1408,7 @@ def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
13981408
def Conv2dQInt8Module_not_depthwise(module, tu: TestUtils):
13991409
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
14001410
weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8)
1401-
bias = torch.rand(6)
1411+
bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
14021412
module.forward(inputVec, weight, bias)
14031413

14041414

@@ -1417,24 +1427,29 @@ def __init__(self):
14171427
]
14181428
)
14191429
def forward(self, input, weight, bias):
1420-
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
1421-
qinput = torch.dequantize(qinput)
1422-
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
1423-
qweight = torch.dequantize(qweight)
1424-
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
1425-
qbias = torch.dequantize(qbias)
1426-
qz = torch.ops.aten.convolution(
1427-
qinput,
1428-
qweight,
1429-
bias=qbias,
1430+
input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1431+
input, 0.01, -25, -128, 127, torch.int8
1432+
)
1433+
weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1434+
weight, 0.01, 50, -128, 127, torch.int8
1435+
)
1436+
1437+
res = torch.ops.aten.convolution(
1438+
input,
1439+
weight,
1440+
bias=bias,
14301441
stride=[2, 1],
14311442
padding=[1, 1],
14321443
dilation=[1, 1],
14331444
transposed=True,
14341445
output_padding=[0, 0],
14351446
groups=1,
14361447
)
1437-
return qz
1448+
1449+
# Use int32 to avoid overflows
1450+
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
1451+
res, 1, 0, -(2**31), 2**31 - 1, torch.int32
1452+
)
14381453

14391454

14401455
@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
@@ -1459,18 +1474,14 @@ def __init__(self, groups=1):
14591474
super().__init__()
14601475

14611476
def _forward(self, inputVec, weight, scales, zeropoints, bias):
1462-
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
1463-
inputVec = torch.dequantize(inputVec)
1464-
1465-
weight = torch._make_per_channel_quantized_tensor(
1466-
weight, scales, zeropoints, axis=0
1477+
inputVec = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1478+
inputVec, 0.01, 7, -128, 127, torch.int8
1479+
)
1480+
weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
1481+
weight, scales, zeropoints, 0, -128, 127, torch.int8
14671482
)
1468-
weight = torch.dequantize(weight)
1469-
1470-
bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
1471-
bias = torch.dequantize(bias)
14721483

1473-
return torch.ops.aten.conv2d(
1484+
conv = torch.ops.aten.conv2d(
14741485
inputVec,
14751486
weight,
14761487
bias=bias,
@@ -1480,6 +1491,11 @@ def _forward(self, inputVec, weight, scales, zeropoints, bias):
14801491
groups=self.groups,
14811492
)
14821493

1494+
# Use int32 to avoid overflows
1495+
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
1496+
conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
1497+
)
1498+
14831499

14841500
class Conv2dQInt8PerChannelModuleDyn(Conv2dQInt8PerChannelModuleBase):
14851501
@export

python/torch_mlir/fx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _module_lowering(
4141
option_string = "{extra-library=" + extra_library_file_name + "}"
4242
run_pipeline_with_repro_report(
4343
torch_mod,
44-
f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
44+
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
4545
"Lowering TorchFX IR -> Torch Backend IR",
4646
enable_ir_printing=verbose,
4747
)

0 commit comments

Comments (0)