 
 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
     quant_min: int,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
     Quantizes a floating-point tensor to an integral tensor.
 
     Args:
-        - input (Tensor): input tensor
-        - scale (float): Quantization scale. Derived from the ratio
+        - input_tensor (Tensor): input tensor
+        - scale (float): Inverse of quantization scale. Derived from the ratio
             between the min/max of the floating-point tensor and the
-            min/max of the quantized range.
+            min/max of the quantized range, and then inverted.
         - zero_point (int): The point which represents 0 in the quantized
             range. For example, consider the floating point range [-1., 2.] and
             quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
@@ -61,7 +61,11 @@ def quantize_per_tensor(
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )
-    return torch.round(input / scale + zero_point).to(dtype)
+    tmp = torch.round(input_tensor * scale + zero_point).to(dtype)
+    return torch.max(
+        torch.min(tmp, torch.tensor(torch.iinfo(dtype).max)),
+        torch.tensor(torch.iinfo(dtype).min),
+    )
 
 
 @impl(m, "dequantize_per_tensor")
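For context, here is a minimal standalone sketch (not the registered `@impl` op; the helper name, example values, and assumed zero point of -2 are illustrative) of what the patched `quantize_per_tensor` computes: the `scale` argument is now treated as the inverse of the quantization scale, and the rounded result is saturated to the integer range of `dtype`. The worked numbers reuse the docstring's example range.

```python
import torch


def quantize_per_tensor_sketch(input_tensor, scale_inv, zero_point, dtype):
    # scale_inv is the *inverse* of the quantization scale, matching the
    # patched op: multiply instead of divide, then saturate to dtype's range.
    tmp = torch.round(input_tensor * scale_inv + zero_point).to(dtype)
    info = torch.iinfo(dtype)
    return torch.max(
        torch.min(tmp, torch.tensor(info.max, dtype=dtype)),
        torch.tensor(info.min, dtype=dtype),
    )


# Docstring example: float range [-1., 2.] mapped to the integer range [-7, 7].
# Quantization scale = (2.0 - (-1.0)) / (7 - (-7)) = 3/14, so its inverse is 14/3;
# 0.0 sits a third of the way into the float range, i.e. near -2.33 in [-7, 7],
# so assume zero_point = -2.
x = torch.tensor([-1.0, 0.0, 2.0])
print(quantize_per_tensor_sketch(x, 14.0 / 3.0, -2, torch.int8))
# -> tensor([-7, -2,  7], dtype=torch.int8)
```

Taking the reciprocal as an argument lets callers precompute a single division instead of dividing per element, which is the pattern the call sites below adopt.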
@@ -173,9 +177,11 @@ def quantized_add(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
+    out_scale_inv = 1 / out_scale
+
     # q_min/q_max are unused args
     return quantize_per_tensor(
-        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+        dequant_X + dequant_Y, out_scale_inv, out_zero_point, -128, 127, dtype
     )
 
 
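The call-site change follows from the new contract: `quantize_per_tensor` now multiplies by its `scale` argument, so a caller holding the real output scale passes its reciprocal. A small sketch with made-up numbers showing the two forms agree after rounding:

```python
import torch

# Hypothetical values, only to illustrate the call-site change in quantized_add:
# the requantization x / out_scale is now written as x * (1 / out_scale).
out_scale = 0.05
out_scale_inv = 1 / out_scale

x = torch.tensor([0.1, -0.3, 1.25])
print(torch.round(x / out_scale))      # what the old call computed
print(torch.round(x * out_scale_inv))  # what the patched call computes
# both: tensor([ 2., -6., 25.])
```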
@@ -206,6 +212,7 @@ def quantized_linear(
         - offset (Tensor): The offset tensor
     """
     out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale_inv = 1 / out_scale
 
     N, K = weight.shape
 
@@ -228,7 +235,7 @@ def quantized_linear(
                 src[m] - in_zero_point, weight[n] - weight_zero_point
             )
             out[m][n] = quantize_per_tensor(
-                dot, out_scale, out_zero_point, -128, 127, torch.int8
+                dot, out_scale_inv, out_zero_point, -128, 127, torch.int8
             )
 
     return out.reshape(*leading_dims, N)
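`quantized_linear` derives its requantization scale from a 32-bit fixed-point multiplier and a shift before inverting it. A sketch mirroring the formula in the hunk above, with illustrative (not real) multiplier and shift values:

```python
import torch

# Illustrative values only: a 32-bit fixed-point multiplier and a shift, as in
# the quantized_linear hunk above. out_scale reconstructs the floating-point
# requantization factor; its reciprocal is what quantize_per_tensor now expects.
out_multiplier = torch.tensor([-1431655765])  # roughly -(2/3) * 2**31
out_shift = torch.tensor([0])

out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
out_scale_inv = 1 / out_scale

print(out_scale)      # ~0.6667
print(out_scale_inv)  # ~1.5
```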