Commit 0b0e2dc

D81187339: [Cadence] Add backend-agnostic implementation for quantized_add
Differential Revision: D81341103
Pull Request resolved: #13820
1 parent 94284d7 commit 0b0e2dc

File tree

2 files changed: +109 -1 lines changed
backends/cadence/aot/ref_implementations.py

Lines changed: 66 additions & 1 deletion
@@ -48,7 +48,13 @@ def quantize_per_tensor(
         is already provided.
     - dtype (torch.dtype): The type of the output tensor
     """
-    supported_quant_types = [torch.int8, torch.int16, torch.int32]
+    supported_quant_types = [
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.uint8,
+        torch.uint16,
+    ]
     if dtype not in supported_quant_types:
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
@@ -112,6 +118,65 @@ def dequantize_per_tensor(
     return (input_tensor - zero_point).to(dtype) * scale


+@impl(m, "quantized_add")
+def quantized_add(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Sums up two quantized tensors and returns another quantized tensor. The intuition
+    is that we want dequant(out) ~= dequant(X) + dequant(Y)
+
+    If we do that math, we get
+    out_scale(out - out_zero_point) = X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)
+
+    Rearranging, we get
+    out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point
+
+    Args:
+    - X (Tensor): The first operand
+    - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized
+        ranges
+    - X_zero_point (Tensor): The quantized mapping of zero for X
+    - Y (Tensor): The second operand
+    - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized
+        ranges
+    - Y_zero_point (Tensor): The quantized mapping of zero for Y
+    - out_scale (float): The ratio between the sizes of the output's floating point and
+        quantized ranges
+    - out_zero_point (int): The quantized mapping of zero for the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if X.dtype != Y.dtype:
+        raise ValueError("X and Y dtypes need to match")
+
+    dtype = X.dtype
+    if dtype not in supported_dtypes:
+        raise ValueError(
+            f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}"
+        )
+
+    if dtype == torch.uint8:
+        X = X.to(torch.int8)
+        Y = Y.to(torch.int8)
+
+    # TODO(agrebenisan): This should be done in fixed point arithmetic, but to match the quantized_add_out.cpp
+    # reference implementation, we'll do it in floating point.
+    dequant_X = X_scale * (X - X_zero_point)
+    dequant_Y = Y_scale * (Y - Y_zero_point)
+
+    # q_min/q_max are unused args
+    return quantize_per_tensor(
+        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
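To make the rearranged formula in the docstring concrete, here is a small worked example. It reuses the same numbers as the new test cases added below; it is a sketch for readability, not an extra example shipped in the commit:

    import torch
    from executorch.backends.cadence.aot.ref_implementations import quantized_add

    # dequant_X = dequant_Y = 0.8 * (5 - 4) = 0.8
    # out = (0.8 + 0.8) / 0.8 + 4 = 6
    out = quantized_add(
        torch.tensor([5], dtype=torch.int8),
        torch.tensor(0.8),                   # X_scale
        torch.tensor(4, dtype=torch.int8),   # X_zero_point
        torch.tensor([5], dtype=torch.int8),
        torch.tensor(0.8),                   # Y_scale
        torch.tensor(4, dtype=torch.int8),   # Y_zero_point
        0.8,                                 # out_scale
        4,                                   # out_zero_point
    )
    print(out)  # tensor([6], dtype=torch.int8)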

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 43 additions & 0 deletions
@@ -13,6 +13,7 @@
 from executorch.backends.cadence.aot.ref_implementations import (
     dequantize_per_tensor,
     quantize_per_tensor,
+    quantized_add,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand

@@ -95,3 +96,45 @@ def test_dequantize_per_tensor(
             torch.allclose(output, expected_output, rtol=0.001, atol=0.001),
             f"Values don't match in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Only these types need to be tested as per ET_FORALL_JARVIS_QUANTIZED_TYPES in
+            # on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h
+            ("int16", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8),
+            ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8),
+        ]
+    )
+    def test_quantized_add(
+        self,
+        name: str,
+        X: int,
+        X_scale: float,
+        X_zero_point: int,
+        Y: int,
+        Y_scale: float,
+        Y_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        expected_value: int,
+        dtype: torch.dtype,
+    ) -> None:
+        X_tensor = torch.tensor([X], dtype=dtype)
+        Y_tensor = torch.tensor([Y], dtype=dtype)
+        expected_output = torch.tensor([expected_value], dtype=dtype)
+
+        output = quantized_add(
+            X_tensor,
+            torch.tensor(X_scale),
+            torch.tensor(X_zero_point, dtype=dtype),
+            Y_tensor,
+            torch.tensor(Y_scale),
+            torch.tensor(Y_zero_point, dtype=dtype),
+            out_scale,
+            out_zero_point,
+        )
+
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Values don't match in {name}: got {output}, expected {expected_output}",
+        )
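Assuming the test file runs under a stock pytest invocation (the repo may also wire it into its own test runner), the new case can be exercised in isolation with something like:

    python -m pytest backends/cadence/aot/tests/test_ref_implementations.py -k test_quantized_add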
