@@ -127,14 +127,14 @@ def dequantize_per_tensor(
127
127
return (input_tensor - zero_point ).to (dtype ) * scale
128
128
129
129
130
- @impl (m , "quantized_add" )
131
- def quantized_add (
130
+ @impl (m , "quantized_add.per_tensor " )
131
+ def quantized_add_per_tensor (
132
132
X : torch .Tensor ,
133
- X_scale : torch . Tensor ,
134
- X_zero_point : torch . Tensor ,
133
+ X_scale : float ,
134
+ X_zero_point : int ,
135
135
Y : torch .Tensor ,
136
- Y_scale : torch . Tensor ,
137
- Y_zero_point : torch . Tensor ,
136
+ Y_scale : float ,
137
+ Y_zero_point : int ,
138
138
out_scale : float ,
139
139
out_zero_point : int ,
140
140
) -> torch .Tensor :
@@ -149,17 +149,17 @@ def quantized_add(
149
149
out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point
150
150
151
151
Args:
152
- - X (Tensor) : The first operand
153
- - X_scale (Tensor) : The ratio between the sizes of X's floating point and quantized
152
+ - X: The first operand
153
+ - X_scale: The ratio between the sizes of X's floating point and quantized
154
154
ranges
155
- - X_zero_point (Tensor) : The quantized mapping of zero for X
156
- - Y (Tensor) : The second operand
157
- - Y_scale (Tensor) : The ratio between the sizes of Y's floating point and quantized
155
+ - X_zero_point: The quantized mapping of zero for X
156
+ - Y: The second operand
157
+ - Y_scale: The ratio between the sizes of Y's floating point and quantized
158
158
ranges
159
- - Y_zero_point (Tensor) : The quantized mapping of zero for Y
160
- - out_scale (float) : The ratio between the sizes of the output's floating point and
159
+ - Y_zero_point: The quantized mapping of zero for Y
160
+ - out_scale: The ratio between the sizes of the output's floating point and
161
161
quantized ranges
162
- - out_zero_point (int) : The quantized mapping of zero for the output
162
+ - out_zero_point: The quantized mapping of zero for the output
163
163
"""
164
164
supported_dtypes = [torch .int8 , torch .uint8 ]
165
165
if X .dtype != Y .dtype :
0 commit comments