
Commit 42ba541

[fx] Fix importing and tests for quantized conv (#3809)
The fx tracer currently does not support tracing "real" quantized tensors. A "real" quantized tensor here means a tensor created with a method like `torch.quantize_per_tensor()` that carries the quantization parameters (scale, zero_point, scheme) in the object itself. However, DQ-Q style fake quantization is now commonly used as a high-level representation of quantized operators and is only lowered to native quantized ops (if available) in the respective hardware backend. Quantization of floating-point modules in PyTorch is also increasingly performed as a graph transformation after exporting/tracing the original module.

```python
# Examples of "real"/native quantization
tens = torch.randint(-127, 127, (1,), dtype=torch.int8)
torch._make_per_tensor_quantized_tensor(tens, 1, 0)
# tensor([90.], size=(1,), dtype=torch.qint8,
#        quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

tens = torch.rand((1,))
torch.quantize_per_tensor(tens, 1, 0, torch.qint8)
# tensor([1.], size=(1,), dtype=torch.qint8,
#        quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

# Example of DQ/Q quantization
import torch.ao.quantization.fx._decomposed
tens = torch.rand((1,))
torch.ops.quantized_decomposed.quantize_per_tensor.default(tens, 1, 0, -128, 127, torch.int8)
# tensor([1], dtype=torch.int8)
```

This means that a typical import flow for a quantized network into/through torch-mlir looks like this: `torch.export() -> quantization transformations on the fx graph -> fx_importer`, where the tensors in the graph are ordinary float/int tensors and the quantization parameters are carried by the DQ/Q ops. These kinds of graphs can be traced without issues.

Currently, our quantized convolution tests use "real" quantized tensors, which means that with the retirement of the `jit_ir_importer` these tests can no longer be imported. I see no reason to stick with "real" quantization in these tests, since both PyTorch 2.0 and our linalg backend use DQ/Q quantization.

This patch updates our quantized convolution tests to use DQ-Q quantization with the ops from `torch.ops.quantized_decomposed`.

Note: For future reference, there seems to be an ongoing consolidation of the ops for the DQ/Q scheme on the PyTorch side (pytorch/ao#986 (comment)).
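The flow above can be exercised end to end with the fx importer. Below is a minimal sketch (not part of this patch), assuming the `torch_mlir.fx.export_and_import` entry point and using illustrative quantization parameters; the DQ/Q ops are written by hand here, whereas a real flow would insert them via the PT2E quantization transformations.

```python
import torch
import torch.ao.quantization.fx._decomposed  # registers torch.ops.quantized_decomposed

from torch_mlir import fx


class DqQAdd(torch.nn.Module):
    def forward(self, x):
        # The int8 input and the DQ/Q ops carry the quantization parameters;
        # no "real" quantized tensors appear anywhere in the graph.
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, 0.05, 0, -128, 127, torch.int8
        )
        y = x + x
        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
            y, 0.05, 0, -128, 127, torch.int8
        )


x = torch.randint(-128, 127, (4,), dtype=torch.int8)
# torch.export()-based tracing followed by the fx importer.
module = fx.export_and_import(DqQAdd(), x)
print(module)
```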
1 parent 140cad5 commit 42ba541

File tree

3 files changed: +56, -48 lines changed


projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -420,15 +420,7 @@
     "CeilFloatModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
-    "Conv2dQInt8Module_basic",
-    "Conv2dQInt8Module_depthwise",
-    "Conv2dQInt8Module_grouped",
-    "Conv2dQInt8Module_not_depthwise",
-    "Conv2dQInt8PerChannelModule_basic",
-    "Conv2dQInt8PerChannelModule_depthwise",
-    "Conv2dQInt8PerChannelModule_grouped",
     "ConvTbcModule_basic",
-    "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
```

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 55 additions & 39 deletions
```diff
@@ -1183,23 +1183,28 @@ def ConvTbcModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6))
 
 
+# For DQ-Q fake quantization ops
+import torch.ao.quantization.fx._decomposed
+
+
 class Conv2dQInt8ModuleBase(torch.nn.Module):
     def __init__(self, groups=1):
         self.groups = groups
         super().__init__()
 
-    def _forward(self, inputVec, weight, bias):
-        inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
-        inputVec = torch.dequantize(inputVec)
-
-        weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3)
-        weight = torch.dequantize(weight)
-
-        bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
-        bias = torch.dequantize(bias)
+    def _forward(self, input, weight, bias):
+        input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            input, 0.01, 7, -128, 127, torch.int8
+        )
+        weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            weight, 0.01, 3, -128, 127, torch.int8
+        )
+        bias = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            bias, 1, 0, -1000, 1000, torch.int32
+        )
 
-        return torch.ops.aten.conv2d(
-            inputVec,
+        conv = torch.ops.aten.conv2d(
+            input,
             weight,
             bias=bias,
             stride=[1, 1],
@@ -1208,6 +1213,11 @@ def _forward(self, inputVec, weight, bias):
             groups=self.groups,
         )
 
+        # Use int32 to avoid overflows
+        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
+        )
+
 
 class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
     @export
@@ -1216,7 +1226,7 @@ class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
             None,
             ([-1, -1, -1, -1], torch.int8, True),
             ([-1, -1, -1, -1], torch.int8, True),
-            ([-1], torch.float, True),
+            ([-1], torch.int32, True),
         ]
     )
     def forward(self, inputVec, weight, bias):
@@ -1230,7 +1240,7 @@ class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase):
             None,
             ([2, 3, 12, 12], torch.int8, True),
             ([3, 1, 5, 3], torch.int8, True),
-            ([3], torch.float, True),
+            ([3], torch.int32, True),
         ]
     )
     def forward(self, inputVec, weight, bias):
@@ -1244,7 +1254,7 @@ class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
             None,
             ([2, 3, 12, 12], torch.int8, True),
             ([6, 1, 5, 3], torch.int8, True),
-            ([6], torch.float, True),
+            ([6], torch.int32, True),
         ]
     )
     def forward(self, inputVec, weight, bias):
@@ -1255,23 +1265,23 @@ def forward(self, inputVec, weight, bias):
 def Conv2dQInt8Module_basic(module, tu: TestUtils):
     inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
     weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(3)
+    bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)
 
 
 @register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn(groups=2))
 def Conv2dQInt8Module_grouped(module, tu: TestUtils):
     inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
     weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(6)
+    bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)
 
 
 @register_test_case(module_factory=lambda: Conv2dQInt8ModuleStatic(groups=3))
 def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
     inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
     weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(3)
+    bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)
 
 
@@ -1281,7 +1291,7 @@ def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
 def Conv2dQInt8Module_not_depthwise(module, tu: TestUtils):
     inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
     weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(6)
+    bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)
 
 
@@ -1300,24 +1310,29 @@ def __init__(self):
         ]
     )
     def forward(self, input, weight, bias):
-        qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
-        qinput = torch.dequantize(qinput)
-        qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
-        qweight = torch.dequantize(qweight)
-        qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
-        qbias = torch.dequantize(qbias)
-        qz = torch.ops.aten.convolution(
-            qinput,
-            qweight,
-            bias=qbias,
+        input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            input, 0.01, -25, -128, 127, torch.int8
+        )
+        weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            weight, 0.01, 50, -128, 127, torch.int8
+        )
+
+        res = torch.ops.aten.convolution(
+            input,
+            weight,
+            bias=bias,
             stride=[2, 1],
             padding=[1, 1],
             dilation=[1, 1],
             transposed=True,
             output_padding=[0, 0],
             groups=1,
         )
-        return qz
+
+        # Use int32 to avoid overflows
+        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            res, 1, 0, -(2**31), 2**31 - 1, torch.int32
+        )
 
 
 @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
@@ -1342,18 +1357,14 @@ def __init__(self, groups=1):
         super().__init__()
 
     def _forward(self, inputVec, weight, scales, zeropoints, bias):
-        inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
-        inputVec = torch.dequantize(inputVec)
-
-        weight = torch._make_per_channel_quantized_tensor(
-            weight, scales, zeropoints, axis=0
+        inputVec = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            inputVec, 0.01, 7, -128, 127, torch.int8
+        )
+        weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+            weight, scales, zeropoints, 0, -128, 127, torch.int8
         )
-        weight = torch.dequantize(weight)
-
-        bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
-        bias = torch.dequantize(bias)
 
-        return torch.ops.aten.conv2d(
+        conv = torch.ops.aten.conv2d(
             inputVec,
             weight,
             bias=bias,
@@ -1363,6 +1374,11 @@ def _forward(self, inputVec, weight, scales, zeropoints, bias):
             groups=self.groups,
         )
 
+        # Use int32 to avoid overflows
+        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
+        )
+
 
 class Conv2dQInt8PerChannelModuleDyn(Conv2dQInt8PerChannelModuleBase):
     @export
```
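For orientation, the snippet below replays the same DQ-Q pattern as the updated `Conv2dQInt8ModuleBase._forward`, but in plain eager mode; the shapes and scales mirror `Conv2dQInt8Module_basic`. It is only a sanity check of the pattern, not part of the test suite.

```python
# Eager-mode sanity check of the DQ-Q convolution pattern used by the updated tests.
import torch
import torch.ao.quantization.fx._decomposed  # registers torch.ops.quantized_decomposed

inp = torch.randint(-128, 127, (2, 4, 7, 8), dtype=torch.int8)
weight = torch.randint(-128, 127, (3, 4, 3, 2), dtype=torch.int8)
bias = torch.randint(-1000, 1000, (3,), dtype=torch.int32)

# Dequantize the int8/int32 operands to float with the decomposed ops.
x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    inp, 0.01, 7, -128, 127, torch.int8
)
w = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    weight, 0.01, 3, -128, 127, torch.int8
)
b = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    bias, 1, 0, -1000, 1000, torch.int32
)

# Float convolution followed by re-quantization to int32, which is wide
# enough to hold the accumulated products without overflow.
conv = torch.ops.aten.conv2d(
    x, w, bias=b, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1
)
out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
)
print(out.dtype, out.shape)  # torch.int32 torch.Size([2, 3, 5, 7])
```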

python/torch_mlir/fx.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -41,7 +41,7 @@ def _module_lowering(
     option_string = "{extra-library=" + extra_library_file_name + "}"
     run_pipeline_with_repro_report(
         torch_mod,
-        f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
+        f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
         "Lowering TorchFX IR -> Torch Backend IR",
         enable_ir_printing=verbose,
     )
```
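With `torch-match-quantized-custom-ops` now scheduled on each function ahead of the backend pipeline, graphs carrying `quantized_decomposed` ops can be lowered past the Torch backend contract. The sketch below is a hypothetical way to exercise that path, assuming the `fx.export_and_import` entry point and the `OutputType` enum from `torch_mlir.compiler_utils`; whether a particular graph lowers cleanly still depends on the chosen backend.

```python
import torch
import torch.ao.quantization.fx._decomposed  # registers torch.ops.quantized_decomposed

from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType


class TinyQuantConv(torch.nn.Module):
    def forward(self, x, w):
        # DQ -> float conv -> Q, matching the pattern used by the updated tests.
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, 0.01, 0, -128, 127, torch.int8
        )
        w = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            w, 0.01, 0, -128, 127, torch.int8
        )
        y = torch.ops.aten.conv2d(x, w)
        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
            y, 1, 0, -(2**31), 2**31 - 1, torch.int32
        )


x = torch.randint(-128, 127, (1, 2, 8, 8), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 2, 3, 3), dtype=torch.int8)
# Lower through the updated pipeline down to linalg-on-tensors.
module = fx.export_and_import(
    TinyQuantConv(), x, w, output_type=OutputType.LINALG_ON_TENSORS
)
print(module)
```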
