@@ -48,7 +48,13 @@ def quantize_per_tensor(
     is already provided.
     - dtype (torch.dtype): The type of the output tensor
     """
-    supported_quant_types = [torch.int8, torch.int16, torch.int32]
+    supported_quant_types = [
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.uint8,
+        torch.uint16,
+    ]
     if dtype not in supported_quant_types:
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
@@ -112,6 +118,65 @@ def dequantize_per_tensor(
     return (input_tensor - zero_point).to(dtype) * scale
 
 
+@impl(m, "quantized_add")
+def quantized_add(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Sums up two quantized tensors and returns another quantized tensor. The intuition
+    is that we want dequant(out) ~= dequant(X) + dequant(Y)
+
+    If we do that math, we get
+    out_scale(out - out_zero_point) = X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)
+
+    Rearranging, we get
+    out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point
+
+    Args:
+    - X (Tensor): The first operand
+    - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized
+        ranges
+    - X_zero_point (Tensor): The quantized mapping of zero for X
+    - Y (Tensor): The second operand
+    - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized
+        ranges
+    - Y_zero_point (Tensor): The quantized mapping of zero for Y
+    - out_scale (float): The ratio between the sizes of the output's floating point and
+        quantized ranges
+    - out_zero_point (int): The quantized mapping of zero for the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if X.dtype != Y.dtype:
+        raise ValueError("X and Y dtypes need to match")
+
+    dtype = X.dtype
+    if dtype not in supported_dtypes:
+        raise ValueError(
+            f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}"
+        )
+
+    if dtype == torch.uint8:
+        X = X.to(torch.int8)
+        Y = Y.to(torch.int8)
+
+    # TODO(agrebenisan): This should be done in fixed-point arithmetic, but to match
+    # the quantized_add_out.cpp reference implementation, we'll do it in floating point.
+    dequant_X = X_scale * (X - X_zero_point)
+    dequant_Y = Y_scale * (Y - Y_zero_point)
+
+    # q_min/q_max are unused args
+    return quantize_per_tensor(
+        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
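
As a sanity check on the derivation in the quantized_add docstring, here is a standalone sketch that replays the same math on concrete values. The tensors, scales, and zero points are made up for illustration, and the final round/clamp step is an assumption about how quantize_per_tensor requantizes (the diff does not show its rounding mode):

import torch

# Made-up quantization parameters, chosen so the arithmetic is easy to follow.
X = torch.tensor([20], dtype=torch.int8)
Y = torch.tensor([-5], dtype=torch.int8)
X_scale, X_zero_point = 0.1, 4
Y_scale, Y_zero_point = 0.2, -1
out_scale, out_zero_point = 0.25, 0

# dequant(X) + dequant(Y): 0.1 * (20 - 4) + 0.2 * (-5 - (-1)) = 1.6 - 0.8 = 0.8
dequant_sum = X_scale * (X - X_zero_point) + Y_scale * (Y - Y_zero_point)

# Assumed requantization: divide by out_scale, round, shift by the zero point,
# clamp to the int8 range.
out = torch.clamp(
    torch.round(dequant_sum / out_scale) + out_zero_point, -128, 127
).to(torch.int8)

print(out)  # tensor([3], dtype=torch.int8); dequantizing: 0.25 * (3 - 0) = 0.75 ~= 0.8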
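A similar sketch for the quantize/dequantize pair the first hunk extends. The dequantize direction matches the formula shown in dequantize_per_tensor, (input - zero_point) * scale; the round-and-clamp in the quantize direction is again an assumption about quantize_per_tensor, whose body is not part of this diff:

import torch

scale, zero_point = 0.05, 10
x = torch.tensor([-1.0, 0.0, 1.0])

# Assumed quantize direction, using uint8, one of the dtypes this diff adds support for.
q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255).to(torch.uint8)

# Dequantize direction, as in dequantize_per_tensor.
x_hat = (q.to(torch.float32) - zero_point) * scale

print(q)      # tensor([ 0, 10, 30], dtype=torch.uint8)
print(x_hat)  # tensor([-0.5000,  0.0000,  1.0000]); -1.0 saturates to -0.5, the
              # smallest value representable with this scale/zero_point in uint8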