24
24
25
25
@impl (m , "quantize_per_tensor" )
26
26
def quantize_per_tensor (
27
- input : torch .Tensor ,
27
+ input_tensor : torch .Tensor ,
28
28
scale : float ,
29
29
zero_point : int ,
30
30
quant_min : int ,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
35
35
Quantizes a floating-point tensor to an integral tensor.
36
36
37
37
Args:
38
- - input (Tensor): input tensor
39
- - scale (float): Quantization scale. Derived from the ratio
38
+ - input_tensor (Tensor): input tensor
39
+ - scale (float): Inverse of quantization scale. Derived from the ratio
40
40
between the min/max of the floating-point tensor and the
41
- min/max of the quantized range.
41
+ min/max of the quantized range, and then inverted .
42
42
- zero_point (int): The point which represents 0 in the quantized
43
43
range. For example, consider the floating point range [-1., 2.] and
44
44
quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
@@ -61,7 +61,12 @@ def quantize_per_tensor(
61
61
raise ValueError (
62
62
f"Unsupported dtype to quantize to. Supported dtypes must be one of { supported_quant_types } "
63
63
)
64
- return torch .round (input / scale + zero_point ).to (dtype )
64
+
65
+ dequantized = torch .round (input_tensor * scale + zero_point ).to (dtype )
66
+ return torch .max (
67
+ torch .min (dequantized , torch .tensor (quant_max )),
68
+ torch .tensor (quant_min ),
69
+ )
65
70
66
71
67
72
@impl (m , "dequantize_per_tensor" )
@@ -173,9 +178,16 @@ def quantized_add(
173
178
dequant_X = X_scale * (X - X_zero_point )
174
179
dequant_Y = Y_scale * (Y - Y_zero_point )
175
180
181
+ out_scale_inv = 1 / out_scale
182
+
176
183
# q_min/q_max are unused args
177
184
return quantize_per_tensor (
178
- dequant_X + dequant_Y , out_scale , out_zero_point , - 128 , 127 , dtype
185
+ dequant_X + dequant_Y ,
186
+ out_scale_inv ,
187
+ out_zero_point ,
188
+ torch .iinfo (dtype ).min ,
189
+ torch .iinfo (dtype ).max ,
190
+ dtype ,
179
191
)
180
192
181
193
@@ -206,6 +218,7 @@ def quantized_linear(
206
218
- offset (Tensor): Unused
207
219
"""
208
220
out_scale = - out_multiplier * (1 / (1 << 31 )) * (2 ** out_shift [0 ])
221
+ out_scale_inv = 1 / out_scale
209
222
210
223
N , K = weight .shape
211
224
@@ -223,7 +236,12 @@ def quantized_linear(
223
236
src - in_zero_point , weight - weight_zero_point , bias
224
237
)
225
238
return quantize_per_tensor (
226
- out , out_scale , out_zero_point , - 128 , 127 , dtype
239
+ out ,
240
+ out_scale_inv ,
241
+ out_zero_point ,
242
+ torch .iinfo (dtype ).min ,
243
+ torch .iinfo (dtype ).max ,
244
+ dtype ,
227
245
).reshape (* leading_dims , N )
228
246
229
247
0 commit comments