
Commit 79c8e49

Remove non-per-tensor quantized add and replace with per-tensor variant
Differential Revision: D81950579
Pull Request resolved: #14093
1 parent 0b4fe31 commit 79c8e49

2 files changed: +18 -18 lines changed


backends/cadence/aot/ref_implementations.py

Lines changed: 14 additions & 14 deletions
@@ -127,14 +127,14 @@ def dequantize_per_tensor(
     return (input_tensor - zero_point).to(dtype) * scale


-@impl(m, "quantized_add")
-def quantized_add(
+@impl(m, "quantized_add.per_tensor")
+def quantized_add_per_tensor(
     X: torch.Tensor,
-    X_scale: torch.Tensor,
-    X_zero_point: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
     Y: torch.Tensor,
-    Y_scale: torch.Tensor,
-    Y_zero_point: torch.Tensor,
+    Y_scale: float,
+    Y_zero_point: int,
     out_scale: float,
     out_zero_point: int,
 ) -> torch.Tensor:

@@ -149,17 +149,17 @@ def quantized_add(
     out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point

     Args:
-        - X (Tensor): The first operand
-        - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized
+        - X: The first operand
+        - X_scale: The ratio between the sizes of X's floating point and quantized
          ranges
-        - X_zero_point (Tensor): The quantized mapping of zero for X
-        - Y (Tensor): The second operand
-        - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized
+        - X_zero_point: The quantized mapping of zero for X
+        - Y: The second operand
+        - Y_scale: The ratio between the sizes of Y's floating point and quantized
          ranges
-        - Y_zero_point (Tensor): The quantized mapping of zero for Y
-        - out_scale (float): The ratio between the sizes of the output's floating point and
+        - Y_zero_point: The quantized mapping of zero for Y
+        - out_scale: The ratio between the sizes of the output's floating point and
          quantized ranges
-        - out_zero_point (int): The quantized mapping of zero for the output
+        - out_zero_point: The quantized mapping of zero for the output
     """
     supported_dtypes = [torch.int8, torch.uint8]
     if X.dtype != Y.dtype:
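For readers skimming the diff: the per-tensor variant takes its scale and zero-point arguments as plain Python scalars instead of tensors, and the docstring formula amounts to dequantize, add in float, requantize. A minimal self-contained sketch of that math, assuming round-to-nearest and saturation to the output dtype range (this is not the repository's implementation, just an illustration of the formula):

import torch

def quantized_add_per_tensor_sketch(
    X: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_scale: float,
    Y_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    dtype: torch.dtype = torch.int8,
) -> torch.Tensor:
    # Dequantize each operand to float, add, then requantize to the
    # output scale/zero point, per the docstring formula:
    #   out = (X_scale*(X - X_zp) + Y_scale*(Y - Y_zp)) / out_scale + out_zp
    x_fp = X_scale * (X.to(torch.float32) - X_zero_point)
    y_fp = Y_scale * (Y.to(torch.float32) - Y_zero_point)
    out_fp = (x_fp + y_fp) / out_scale + out_zero_point
    # Rounding mode and clamping to the dtype range are assumptions here.
    info = torch.iinfo(dtype)
    return torch.clamp(torch.round(out_fp), info.min, info.max).to(dtype)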

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 4 additions & 4 deletions
@@ -124,11 +124,11 @@ def test_quantized_add(

         output = torch.ops.cadence.quantized_add(
             X_tensor,
-            torch.tensor(X_scale),
-            torch.tensor(X_zero_point, dtype=dtype),
+            X_scale,
+            X_zero_point,
             Y_tensor,
-            torch.tensor(Y_scale),
-            torch.tensor(Y_zero_point, dtype=dtype),
+            Y_scale,
+            Y_zero_point,
             out_scale,
             out_zero_point,
         )
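With the per-tensor signature, the test now passes scale and zero-point values as plain floats and ints. A quick self-contained sanity check of the docstring formula with scalar parameters (illustrative values only, not taken from the test):

import torch

# Example int8 inputs and per-tensor quantization parameters (hypothetical values).
X = torch.tensor([10, 20, 30], dtype=torch.int8)
Y = torch.tensor([-5, 0, 5], dtype=torch.int8)
X_scale, X_zero_point = 0.1, 0
Y_scale, Y_zero_point = 0.2, 2
out_scale, out_zero_point = 0.25, -1

# out = (X_scale*(X - X_zp) + Y_scale*(Y - Y_zp)) / out_scale + out_zp
out_fp = (
    X_scale * (X.float() - X_zero_point)
    + Y_scale * (Y.float() - Y_zero_point)
) / out_scale + out_zero_point
out = torch.clamp(torch.round(out_fp), -128, 127).to(torch.int8)
print(out)  # -> tensor([-3, 5, 13], dtype=torch.int8)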
