 
 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
     quant_min: int,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
     Quantizes a floating-point tensor to an integral tensor.
 
     Args:
-        - input (Tensor): input tensor
-        - scale (float): Quantization scale. Derived from the ratio
+        - input_tensor (Tensor): input tensor
+        - scale (float): Inverse of quantization scale. Derived from the ratio
             between the min/max of the floating-point tensor and the
-            min/max of the quantized range.
+            min/max of the quantized range, and then inverted.
         - zero_point (int): The point which represents 0 in the quantized
             range. For example, consider the floating point range [-1., 2.] and
             quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
@@ -61,7 +61,13 @@ def quantize_per_tensor(
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )
-    return torch.round(input / scale + zero_point).to(dtype)
+
+
+    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    return torch.max(
+        torch.min(quantized, torch.tensor(quant_max)),
+        torch.tensor(quant_min),
+    )
 
 
 @impl(m, "dequantize_per_tensor")
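To make the new contract concrete, here is a small self-contained sketch (not part of the patch) that reproduces the kernel's math using the docstring's example of the float range [-1., 2.] mapped onto the quantized range [-7, 7]; the input values and the derived scale/zero_point are illustrative only. The caller now passes the scale already inverted, so the kernel multiplies instead of dividing, and the final max/min pair keeps rounded values inside [quant_min, quant_max]:

```python
import torch

# Docstring example: float range [-1., 2.] mapped onto quantized range [-7, 7].
# The values below are illustrative; they are not taken from the PR.
x = torch.tensor([-1.0, 0.0, 0.5, 2.0, 3.0])
quant_min, quant_max = -7, 7
scale = (2.0 - (-1.0)) / (quant_max - quant_min)    # float units per quantized step (~0.214)
inv_scale = 1.0 / scale                             # what the op now expects as `scale`
zero_point = round(quant_min - (-1.0) * inv_scale)  # -2: where 0.0 lands, 1/3 of the way up

# Same recipe as the patched kernel: multiply by the inverse scale, round,
# cast, then clamp into [quant_min, quant_max].
q = torch.round(x * inv_scale + zero_point).to(torch.int8)
q = torch.max(torch.min(q, torch.tensor(quant_max)), torch.tensor(quant_min))
print(q)  # tensor([-7, -2,  0,  7,  7], dtype=torch.int8); 3.0 saturates at quant_max
```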
@@ -173,9 +179,11 @@ def quantized_add(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
+    out_scale_inv = 1 / out_scale
+
     # q_min/q_max are unused args
     return quantize_per_tensor(
-        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+        dequant_X + dequant_Y, out_scale_inv, out_zero_point, torch.iinfo(dtype).min, torch.iinfo(dtype).max, dtype
     )
 
 
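The same convention flows through quantized_add: the float sum is requantized with the precomputed inverse of the output scale, and the clamp bounds now come from torch.iinfo(dtype) rather than the hard-coded int8 pair, so with the clamp added to quantize_per_tensor the quant_min/quant_max arguments actually constrain out-of-range sums. A rough standalone sketch of the recipe, with made-up scales and zero points:

```python
import torch

# Toy int8 add following the same recipe: dequantize, add in float, then
# requantize with the inverse of the output scale. All scales and zero points
# here are made up for the example.
X = torch.tensor([10, 20], dtype=torch.int8)
Y = torch.tensor([30, 40], dtype=torch.int8)
X_scale, X_zero_point = 0.05, 0
Y_scale, Y_zero_point = 0.10, 0
out_scale, out_zero_point = 0.20, 0
dtype = torch.int8

dequant_X = X_scale * (X - X_zero_point)   # [0.5, 1.0]
dequant_Y = Y_scale * (Y - Y_zero_point)   # [3.0, 4.0]
out_scale_inv = 1 / out_scale

q = torch.round((dequant_X + dequant_Y) * out_scale_inv + out_zero_point).to(dtype)
q = torch.clamp(q, torch.iinfo(dtype).min, torch.iinfo(dtype).max)
print(q)  # tensor([18, 25], dtype=torch.int8)
```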
@@ -206,6 +214,7 @@ def quantized_linear(
     - offset (Tensor): Unused
     """
     out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale_inv = 1 / out_scale
 
     N, K = weight.shape
 
@@ -223,7 +232,7 @@ def quantized_linear(
         src - in_zero_point, weight - weight_zero_point, bias
     )
     return quantize_per_tensor(
-        out, out_scale_inv, out_zero_point, torch.iinfo(dtype).min, torch.iinfo(dtype).max, dtype
+        out, out_scale_inv, out_zero_point, torch.iinfo(dtype).min, torch.iinfo(dtype).max, dtype
     ).reshape(*leading_dims, N)
 
 
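In quantized_linear the output scale itself is derived from a multiplier/shift pair before being inverted for requantization; the `1 << 31` factor suggests the multiplier is treated as a fixed-point value with 31 fractional bits. A minimal sketch of that derivation with an illustrative multiplier and shift (real values come from the quantizer, not chosen here):

```python
import torch

# out_multiplier acts as a fixed-point multiplier with 31 fractional bits and
# out_shift as a power-of-two exponent; both values below are illustrative.
out_multiplier = torch.tensor([-1288490189])   # roughly -0.6 in that fixed-point form
out_shift = torch.tensor([0])

# Same expression as the op, followed by the inversion that is now passed to
# quantize_per_tensor together with the torch.iinfo(dtype) bounds.
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
out_scale_inv = 1 / out_scale
print(out_scale.item())      # ~0.6
print(out_scale_inv.item())  # ~1.67
```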