@@ -48,7 +48,13 @@ def quantize_per_tensor(
         is already provided.
     - dtype (torch.dtype): The type of the output tensor
     """
-    supported_quant_types = [torch.int8, torch.int16, torch.int32]
+    supported_quant_types = [
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.uint8,
+        torch.uint16,
+    ]
     if dtype not in supported_quant_types:
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
@@ -116,6 +122,66 @@ def dequantize_per_tensor(
     return (input - zero_point).to(dtype) * scale


+@impl(m, "quantized_add")
+def quantized_add(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
137+ Sums up two quantized tensors and returns another quantized tensor. The intuition
138+ is that we want dequant(out) ~= dequant(X) + dequant(Y)
139+
140+ If we do that math, we get
141+ out_scale(out - out_zero_point) = X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)
142+
143+ Rearranging, we get
144+ out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point
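+
+    For example, with X_scale = Y_scale = 0.1, both zero points at 0, and
+    out_scale = 0.2: X = 10 and Y = 30 dequantize to 1.0 and 3.0, so
+    out = (1.0 + 3.0) / 0.2 + 0 = 20, which dequantizes back to 4.0.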
+
+    Args:
+    - X (Tensor): The first operand
+    - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized
+        ranges
+    - X_zero_point (Tensor): The quantized mapping of zero for X
+    - Y (Tensor): The second operand
+    - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized
+        ranges
+    - Y_zero_point (Tensor): The quantized mapping of zero for Y
+    - out_scale (float): The ratio between the sizes of the output's floating point and
+        quantized ranges
+    - out_zero_point (int): The quantized mapping of zero for the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if X.dtype != Y.dtype:
+        raise ValueError("X and Y dtypes need to match")
+
+    dtype = X.dtype
+    if dtype not in supported_dtypes:
+        raise ValueError(
+            f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}"
+        )
+
+    if dtype == torch.uint8:
+        X = X.to(torch.int8)
+        Y = Y.to(torch.int8)
+
+    # TODO(agrebenisan): This should be done in fixed point arithmetic, but to match the quantized_add_out.cpp
+    # reference implementation, we'll do it in floating point.
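+    # A fixed-point version would fold X_scale/out_scale and Y_scale/out_scale
+    # into integer multipliers, e.g. (illustrative constants, not the reference
+    # scheme): mult_x = round(X_scale / out_scale * 2**16), then accumulate
+    # acc = mult_x * (X - X_zero_point) + mult_y * (Y - Y_zero_point) in int32
+    # and take out = ((acc + 2**15) >> 16) + out_zero_point.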
+    dequant_X = X_scale * (X - X_zero_point)
+    dequant_Y = Y_scale * (Y - Y_zero_point)
+    inv_out_scale = 1 / out_scale
+
+    # q_min/q_max are unused args
+    return quantize_per_tensor(
+        dequant_X + dequant_Y, inv_out_scale, out_zero_point, -128, 127, dtype
+    )
+
+
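+# Minimal usage sketch for quantized_add (example values only; the exact output
+# depends on quantize_per_tensor's rounding):
+#
+#     x = torch.tensor([12], dtype=torch.int8)  # dequantizes to 0.5 * (12 - 2) = 5.0
+#     y = torch.tensor([4], dtype=torch.int8)   # dequantizes to 1.0 * (4 - 0) = 4.0
+#     out = quantized_add(
+#         x, torch.tensor(0.5), torch.tensor(2),
+#         y, torch.tensor(1.0), torch.tensor(0),
+#         out_scale=1.5, out_zero_point=-3,
+#     )
+#
+# By the docstring formula, out = (5.0 + 4.0) / 1.5 + (-3) = 3, which
+# dequantizes back to 1.5 * (3 - (-3)) = 9.0 ~= 5.0 + 4.0.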
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,