@@ -62,7 +62,7 @@ def quantize_per_tensor(
6262 ]
6363 if dtype not in supported_quant_types :
6464 raise ValueError (
65- f"Unsupported dtype to quantize to. Supported dtypes must be one of { supported_quant_types } "
65+ f"Unsupported dtype to quantize to { dtype } . Supported dtypes must be one of { supported_quant_types } "
6666 )
6767
6868 return torch .ops .quantized_decomposed .quantize_per_tensor (
@@ -264,7 +264,7 @@ def quantized_linear_common(
264264 supported_dtypes = [torch .int8 , torch .uint8 , torch .int32 ]
265265 if dtype not in supported_dtypes :
266266 raise ValueError (
267- f"Unsupported dtype to quantize to. Supported dtypes must be one of { supported_dtypes } "
267+ f"Unsupported dtype to quantize to { dtype } . Supported dtypes must be one of { supported_dtypes } "
268268 )
269269
270270 out = torch .nn .functional .linear (
@@ -427,25 +427,27 @@ def quantized_matmul(
427427 - out_multiplier (int): The multiplier used to scale the output
428428 - out_shift (int): The shift used to scale the output
429429 - out_zero_point (int): The quantized mapping of zero for the output
430- - transposed (bool): Whether to transpose the weight tensor
430+ - transposed (bool): Whether Y is transposed.
431431 """
432432 if bias is not None and not torch .all (bias == 0 ):
433433 raise ValueError ("bias must be None or all zeros since unused in out variant" )
434434
435- # Looks weird, but quantized linear assumes weights are pre-transposed,
436- # hence we transpose only if `transposed` is False.
437- if not transposed :
438- Y = Y .T
435+ if transposed :
436+ Y = Y .transpose (- 1 , - 2 )
439437
440- return quantized_linear_common (
441- X ,
442- Y ,
443- bias or torch .zeros (1 , dtype = torch .int32 ),
444- X_zero_point ,
445- Y_zero_point ,
446- out_multiplier ,
447- out_shift ,
438+ out_scale = 1.0 / (- out_multiplier * (1 / (1 << 31 )) * (2 ** out_shift ))
439+
440+ out = torch .matmul (
441+ (X - X_zero_point ).float (),
442+ (Y - Y_zero_point ).float (),
443+ )
444+ return quantize_per_tensor (
445+ out ,
446+ out_scale ,
448447 out_zero_point ,
448+ torch .iinfo (X .dtype ).min ,
449+ torch .iinfo (X .dtype ).max ,
450+ X .dtype ,
449451 )
450452
451453
0 commit comments