@@ -127,14 +127,14 @@ def dequantize_per_tensor(
127127 return (input_tensor - zero_point ).to (dtype ) * scale
128128
129129
130- @impl (m , "quantized_add" )
131- def quantized_add (
130+ @impl (m , "quantized_add.per_tensor " )
131+ def quantized_add_per_tensor (
132132 X : torch .Tensor ,
133- X_scale : torch . Tensor ,
134- X_zero_point : torch . Tensor ,
133+ X_scale : float ,
134+ X_zero_point : int ,
135135 Y : torch .Tensor ,
136- Y_scale : torch . Tensor ,
137- Y_zero_point : torch . Tensor ,
136+ Y_scale : float ,
137+ Y_zero_point : int ,
138138 out_scale : float ,
139139 out_zero_point : int ,
140140) -> torch .Tensor :
@@ -149,17 +149,17 @@ def quantized_add(
149149 out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point
150150
151151 Args:
152- - X (Tensor) : The first operand
153- - X_scale (Tensor) : The ratio between the sizes of X's floating point and quantized
152+ - X: The first operand
153+ - X_scale: The ratio between the sizes of X's floating point and quantized
154154 ranges
155- - X_zero_point (Tensor) : The quantized mapping of zero for X
156- - Y (Tensor) : The second operand
157- - Y_scale (Tensor) : The ratio between the sizes of Y's floating point and quantized
155+ - X_zero_point: The quantized mapping of zero for X
156+ - Y: The second operand
157+ - Y_scale: The ratio between the sizes of Y's floating point and quantized
158158 ranges
159- - Y_zero_point (Tensor) : The quantized mapping of zero for Y
160- - out_scale (float) : The ratio between the sizes of the output's floating point and
159+ - Y_zero_point: The quantized mapping of zero for Y
160+ - out_scale: The ratio between the sizes of the output's floating point and
161161 quantized ranges
162- - out_zero_point (int) : The quantized mapping of zero for the output
162+ - out_zero_point: The quantized mapping of zero for the output
163163 """
164164 supported_dtypes = [torch .int8 , torch .uint8 ]
165165 if X .dtype != Y .dtype :
0 commit comments