@@ -48,7 +48,13 @@ def quantize_per_tensor(
             is already provided.
         - dtype (torch.dtype): The type of the output tensor
     """
-    supported_quant_types = [torch.int8, torch.int16, torch.int32]
+    supported_quant_types = [
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.uint8,
+        torch.uint16,
+    ]
     if dtype not in supported_quant_types:
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
@@ -112,6 +118,65 @@ def dequantize_per_tensor(
     return (input_tensor - zero_point).to(dtype) * scale


+@impl(m, "quantized_add")
+def quantized_add(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Sums two quantized tensors and returns another quantized tensor. The intuition
+    is that we want dequant(out) ~= dequant(X) + dequant(Y)
+
+    Writing that out, we get
+    out_scale * (out - out_zero_point) = X_scale * (X - X_zero_point) + Y_scale * (Y - Y_zero_point)
+
+    Rearranging, we get
+    out = (X_scale * (X - X_zero_point) + Y_scale * (Y - Y_zero_point)) / out_scale + out_zero_point
+
+    Args:
+        - X (Tensor): The first operand
+        - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized
+            ranges
+        - X_zero_point (Tensor): The quantized mapping of zero for X
+        - Y (Tensor): The second operand
+        - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized
+            ranges
+        - Y_zero_point (Tensor): The quantized mapping of zero for Y
+        - out_scale (float): The ratio between the sizes of the output's floating point and
+            quantized ranges
+        - out_zero_point (int): The quantized mapping of zero for the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if X.dtype != Y.dtype:
+        raise ValueError("X and Y dtypes need to match")
+
+    dtype = X.dtype
+    if dtype not in supported_dtypes:
+        raise ValueError(
+            f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}"
+        )
+
+    if dtype == torch.uint8:
+        X = X.to(torch.int8)
+        Y = Y.to(torch.int8)
+
+    # TODO(agrebenisan): This should be done in fixed-point arithmetic, but to match
+    # the quantized_add_out.cpp reference implementation, we do it in floating point.
+    dequant_X = X_scale * (X - X_zero_point)
+    dequant_Y = Y_scale * (Y - Y_zero_point)
+
+    # q_min/q_max are unused args of quantize_per_tensor
+    return quantize_per_tensor(
+        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
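The docstring's rearranged formula is easy to sanity-check numerically. Below is a small worked example with made-up scales and zero points (illustrative values, not from this PR), inlining the round-and-clamp step that `quantize_per_tensor` presumably applies:

```python
import torch

# Check that dequant(out) ~= dequant(X) + dequant(Y) with concrete numbers.
X = torch.tensor([10], dtype=torch.int8)
Y = torch.tensor([20], dtype=torch.int8)
X_scale, X_zero_point = torch.tensor(0.1), torch.tensor(2)
Y_scale, Y_zero_point = torch.tensor(0.05), torch.tensor(-4)
out_scale, out_zero_point = 0.02, 0

dequant_X = X_scale * (X - X_zero_point)  # 0.1  * (10 - 2) = 0.8
dequant_Y = Y_scale * (Y - Y_zero_point)  # 0.05 * (20 + 4) = 1.2
out = torch.clamp(
    torch.round((dequant_X + dequant_Y) / out_scale) + out_zero_point, -128, 127
).to(torch.int8)

assert out.item() == 100  # (0.8 + 1.2) / 0.02 + 0 = 100
# dequant(out) = 0.02 * (100 - 0) = 2.0, which equals dequant_X + dequant_Y.
```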