@@ -1300,23 +1300,28 @@ def ConvTbcModule_basic(module, tu: TestUtils):
13001300 module .forward (tu .rand (9 , 4 , 5 ), tu .rand (3 , 5 , 6 ), tu .rand (6 ))
13011301
13021302
1303+ # For DQ-Q fake quantization ops
1304+ import torch .ao .quantization .fx ._decomposed
1305+
1306+
13031307class Conv2dQInt8ModuleBase (torch .nn .Module ):
13041308 def __init__ (self , groups = 1 ):
13051309 self .groups = groups
13061310 super ().__init__ ()
13071311
1308- def _forward (self , inputVec , weight , bias ):
1309- inputVec = torch ._make_per_tensor_quantized_tensor (inputVec , 0.01 , 7 )
1310- inputVec = torch .dequantize (inputVec )
1311-
1312- weight = torch ._make_per_tensor_quantized_tensor (weight , 0.01 , 3 )
1313- weight = torch .dequantize (weight )
1314-
1315- bias = torch .quantize_per_tensor (bias , 0.0001 , 0 , torch .qint32 )
1316- bias = torch .dequantize (bias )
1312+ def _forward (self , input , weight , bias ):
1313+ input = torch .ops .quantized_decomposed .dequantize_per_tensor .default (
1314+ input , 0.01 , 7 , - 128 , 127 , torch .int8
1315+ )
1316+ weight = torch .ops .quantized_decomposed .dequantize_per_tensor .default (
1317+ weight , 0.01 , 3 , - 128 , 127 , torch .int8
1318+ )
1319+ bias = torch .ops .quantized_decomposed .dequantize_per_tensor .default (
1320+ bias , 1 , 0 , - 1000 , 1000 , torch .int32
1321+ )
13171322
1318- return torch .ops .aten .conv2d (
1319- inputVec ,
1323+ conv = torch .ops .aten .conv2d (
1324+ input ,
13201325 weight ,
13211326 bias = bias ,
13221327 stride = [1 , 1 ],
@@ -1325,6 +1330,11 @@ def _forward(self, inputVec, weight, bias):
13251330 groups = self .groups ,
13261331 )
13271332
1333+ # Use int32 to avoid overflows
1334+ return torch .ops .quantized_decomposed .quantize_per_tensor .default (
1335+ conv , 1 , 0 , - (2 ** 31 ), 2 ** 31 - 1 , torch .int32
1336+ )
1337+
13281338
13291339class Conv2dQInt8ModuleDyn (Conv2dQInt8ModuleBase ):
13301340 @export
@@ -1333,7 +1343,7 @@ class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
13331343 None ,
13341344 ([- 1 , - 1 , - 1 , - 1 ], torch .int8 , True ),
13351345 ([- 1 , - 1 , - 1 , - 1 ], torch .int8 , True ),
1336- ([- 1 ], torch .float , True ),
1346+ ([- 1 ], torch .int32 , True ),
13371347 ]
13381348 )
13391349 def forward (self , inputVec , weight , bias ):
@@ -1347,7 +1357,7 @@ class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase):
13471357 None ,
13481358 ([2 , 3 , 12 , 12 ], torch .int8 , True ),
13491359 ([3 , 1 , 5 , 3 ], torch .int8 , True ),
1350- ([3 ], torch .float , True ),
1360+ ([3 ], torch .int32 , True ),
13511361 ]
13521362 )
13531363 def forward (self , inputVec , weight , bias ):
@@ -1361,7 +1371,7 @@ class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
13611371 None ,
13621372 ([2 , 3 , 12 , 12 ], torch .int8 , True ),
13631373 ([6 , 1 , 5 , 3 ], torch .int8 , True ),
1364- ([6 ], torch .float , True ),
1374+ ([6 ], torch .int32 , True ),
13651375 ]
13661376 )
13671377 def forward (self , inputVec , weight , bias ):
@@ -1372,23 +1382,23 @@ def forward(self, inputVec, weight, bias):
def Conv2dQInt8Module_basic(module, tu: "TestUtils"):
    """E2E driver: feeds random int8 activations/weights and an int32 bias
    to the quantized conv2d module under test.

    NOTE(review): reconstructed from a garbled diff rendering (fused line
    numbers, injected spaces); the post-patch "+" side was kept.  The
    ``@register_test_case`` decorator preceding this def sat outside the
    visible span and is not reproduced here.  ``TestUtils`` is quoted so the
    annotation does not require the name at import time.
    """
    # int8 range matches the quant_min/quant_max (-128..127) used by the
    # module's dequantize_per_tensor calls.
    inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
    weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
    # The patch switched the bias from float to int32; -1000..1000 mirrors
    # the module's int32 bias dequantization bounds.
    bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
    module.forward(inputVec, weight, bias)
13771387
13781388
# NOTE(review): reconstructed from a garbled diff rendering; "+" side kept.
@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn(groups=2))
def Conv2dQInt8Module_grouped(module, tu: TestUtils):
    """Grouped quantized conv2d e2e case (groups=2): 8 input channels,
    weight of shape (6, 4, 3, 2), and an int32 bias of length 6."""
    inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
    weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
    # int32 bias in the module's dequantization range (-1000..1000).
    bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
    module.forward(inputVec, weight, bias)
13851395
13861396
# NOTE(review): reconstructed from a garbled diff rendering; "+" side kept.
@register_test_case(module_factory=lambda: Conv2dQInt8ModuleStatic(groups=3))
def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
    """Depthwise quantized conv2d e2e case (groups=3, one filter per input
    channel): weight (3, 1, 5, 3) and an int32 bias of length 3."""
    inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
    weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
    # int32 bias in the module's dequantization range (-1000..1000).
    bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
    module.forward(inputVec, weight, bias)
13931403
13941404
@@ -1398,7 +1408,7 @@ def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
def Conv2dQInt8Module_not_depthwise(module, tu: "TestUtils"):
    """Grouped-but-not-depthwise quantized conv2d e2e case: 3 input
    channels, weight (6, 1, 5, 3) (two filters per group), int32 bias.

    NOTE(review): reconstructed from a garbled diff rendering; "+" side
    kept.  The ``@register_test_case`` decorator for this function was
    outside the visible span and is not reproduced here.  ``TestUtils`` is
    quoted so the annotation does not require the name at import time.
    """
    inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
    weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8)
    # int32 bias in the module's dequantization range (-1000..1000).
    bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
    module.forward(inputVec, weight, bias)
14031413
14041414
@@ -1417,24 +1427,29 @@ def __init__(self):
14171427 ]
14181428 )
14191429 def forward (self , input , weight , bias ):
1420- qinput = torch ._make_per_tensor_quantized_tensor (input , 0.01 , - 25 )
1421- qinput = torch .dequantize (qinput )
1422- qweight = torch ._make_per_tensor_quantized_tensor (weight , 0.01 , 50 )
1423- qweight = torch .dequantize (qweight )
1424- qbias = torch .quantize_per_tensor (bias , 0.0001 , 0 , torch .qint32 )
1425- qbias = torch .dequantize (qbias )
1426- qz = torch .ops .aten .convolution (
1427- qinput ,
1428- qweight ,
1429- bias = qbias ,
1430+ input = torch .ops .quantized_decomposed .dequantize_per_tensor .default (
1431+ input , 0.01 , - 25 , - 128 , 127 , torch .int8
1432+ )
1433+ weight = torch .ops .quantized_decomposed .dequantize_per_tensor .default (
1434+ weight , 0.01 , 50 , - 128 , 127 , torch .int8
1435+ )
1436+
1437+ res = torch .ops .aten .convolution (
1438+ input ,
1439+ weight ,
1440+ bias = bias ,
14301441 stride = [2 , 1 ],
14311442 padding = [1 , 1 ],
14321443 dilation = [1 , 1 ],
14331444 transposed = True ,
14341445 output_padding = [0 , 0 ],
14351446 groups = 1 ,
14361447 )
1437- return qz
1448+
1449+ # Use int32 to avoid overflows
1450+ return torch .ops .quantized_decomposed .quantize_per_tensor .default (
1451+ res , 1 , 0 , - (2 ** 31 ), 2 ** 31 - 1 , torch .int32
1452+ )
14381453
14391454
14401455@register_test_case (module_factory = lambda : ConvTranspose2DQInt8Module ())
@@ -1459,18 +1474,14 @@ def __init__(self, groups=1):
14591474 super ().__init__ ()
14601475
14611476 def _forward (self , inputVec , weight , scales , zeropoints , bias ):
1462- inputVec = torch ._make_per_tensor_quantized_tensor ( inputVec , 0.01 , 7 )
1463- inputVec = torch .dequantize ( inputVec )
1464-
1465- weight = torch ._make_per_channel_quantized_tensor (
1466- weight , scales , zeropoints , axis = 0
1477+ inputVec = torch .ops . quantized_decomposed . dequantize_per_tensor . default (
1478+ inputVec , 0.01 , 7 , - 128 , 127 , torch .int8
1479+ )
1480+ weight = torch .ops . quantized_decomposed . dequantize_per_channel . default (
1481+ weight , scales , zeropoints , 0 , - 128 , 127 , torch . int8
14671482 )
1468- weight = torch .dequantize (weight )
1469-
1470- bias = torch .quantize_per_tensor (bias , 0.0001 , 0 , torch .qint32 )
1471- bias = torch .dequantize (bias )
14721483
1473- return torch .ops .aten .conv2d (
1484+ conv = torch .ops .aten .conv2d (
14741485 inputVec ,
14751486 weight ,
14761487 bias = bias ,
@@ -1480,6 +1491,11 @@ def _forward(self, inputVec, weight, scales, zeropoints, bias):
14801491 groups = self .groups ,
14811492 )
14821493
1494+ # Use int32 to avoid overflows
1495+ return torch .ops .quantized_decomposed .quantize_per_tensor .default (
1496+ conv , 1 , 0 , - (2 ** 31 ), 2 ** 31 - 1 , torch .int32
1497+ )
1498+
14831499
14841500class Conv2dQInt8PerChannelModuleDyn (Conv2dQInt8PerChannelModuleBase ):
14851501 @export
0 commit comments