2424
2525@impl (m , "quantize_per_tensor" )
2626def quantize_per_tensor (
27- input : torch .Tensor ,
27+ input_tensor : torch .Tensor ,
2828 scale : float ,
2929 zero_point : int ,
3030 quant_min : int ,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
3535 Quantizes a floating-point tensor to an integral tensor.
3636
3737 Args:
38- - input (Tensor): input tensor
39- - scale (float): Quantization scale. Derived from the ratio
38+ - input_tensor (Tensor): input tensor
39+ - scale (float): Inverse of quantization scale. Derived from the ratio
4040 between the min/max of the floating-point tensor and the
41- min/max of the quantized range.
41+ min/max of the quantized range, and then inverted .
4242 - zero_point (int): The point which represents 0 in the quantized
4343 range. For example, consider the floating point range [-1., 2.] and
4444 quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
@@ -61,7 +61,12 @@ def quantize_per_tensor(
6161 raise ValueError (
6262 f"Unsupported dtype to quantize to. Supported dtypes must be one of { supported_quant_types } "
6363 )
64- return torch .round (input / scale + zero_point ).to (dtype )
64+
65+ dequantized = torch .round (input_tensor * scale + zero_point ).to (dtype )
66+ return torch .max (
67+ torch .min (dequantized , torch .tensor (quant_max )),
68+ torch .tensor (quant_min ),
69+ )
6570
6671
6772@impl (m , "dequantize_per_tensor" )
@@ -173,9 +178,16 @@ def quantized_add(
173178 dequant_X = X_scale * (X - X_zero_point )
174179 dequant_Y = Y_scale * (Y - Y_zero_point )
175180
181+ out_scale_inv = 1 / out_scale
182+
176183 # q_min/q_max are unused args
177184 return quantize_per_tensor (
178- dequant_X + dequant_Y , out_scale , out_zero_point , - 128 , 127 , dtype
185+ dequant_X + dequant_Y ,
186+ out_scale_inv ,
187+ out_zero_point ,
188+ torch .iinfo (dtype ).min ,
189+ torch .iinfo (dtype ).max ,
190+ dtype ,
179191 )
180192
181193
@@ -206,6 +218,7 @@ def quantized_linear(
206218 - offset (Tensor): Unused
207219 """
208220 out_scale = - out_multiplier * (1 / (1 << 31 )) * (2 ** out_shift [0 ])
221+ out_scale_inv = 1 / out_scale
209222
210223 N , K = weight .shape
211224
@@ -223,7 +236,12 @@ def quantized_linear(
223236 src - in_zero_point , weight - weight_zero_point , bias
224237 )
225238 return quantize_per_tensor (
226- out , out_scale , out_zero_point , - 128 , 127 , dtype
239+ out ,
240+ out_scale_inv ,
241+ out_zero_point ,
242+ torch .iinfo (dtype ).min ,
243+ torch .iinfo (dtype ).max ,
244+ dtype ,
227245 ).reshape (* leading_dims , N )
228246
229247
0 commit comments