
Commit e9903b8

Add int8/uint8 specialized variants of quantized_add_per_tensor
The new asym8s (asymmetric signed 8-bit, i.e. int8 with a zero point) and asym8u (unsigned, uint8) variants validate input dtypes and delegate to the generic quantized_add_per_tensor reference implementation.

Differential Revision: D81951110
Pull Request resolved: #14094
1 parent 79c8e49 commit e9903b8

File tree

2 files changed: +64 -1 lines changed


backends/cadence/aot/ref_implementations.py

Lines changed: 42 additions & 0 deletions
@@ -193,6 +193,48 @@ def quantized_add_per_tensor(
     )
 
 
+@impl(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
+def quantized_add_asym8sxasym8s_asym8s_per_tensor(
+    X: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_scale: float,
+    Y_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    if X.dtype != torch.int8:
+        raise ValueError("X dtype must be torch.int8")
+    if Y.dtype != torch.int8:
+        raise ValueError("Y dtype must be torch.int8")
+
+    return quantized_add_per_tensor(
+        X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
+    )
+
+
+@impl(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor")
+def quantized_add_asym8uxasym8u_asym8u_per_tensor(
+    X: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_scale: float,
+    Y_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    if X.dtype != torch.uint8:
+        raise ValueError("X dtype must be torch.uint8")
+    if Y.dtype != torch.uint8:
+        raise ValueError("Y dtype must be torch.uint8")
+
+    return quantized_add_per_tensor(
+        X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
+    )
+
+
 def quantized_linear_common(
     src: torch.Tensor,
     weight: torch.Tensor,
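
Both specializations share the generic per-tensor semantics: dequantize each input with its scale and zero point, add in floating point, then requantize with the output's scale and zero point. A minimal sketch of that math for reference; quantized_add_sketch is an illustrative helper, not the actual quantized_add_per_tensor body:

import torch

def quantized_add_sketch(
    X: torch.Tensor, X_scale: float, X_zero_point: int,
    Y: torch.Tensor, Y_scale: float, Y_zero_point: int,
    out_scale: float, out_zero_point: int,
) -> torch.Tensor:
    # Dequantize: real = scale * (quantized - zero_point)
    x_fp = X_scale * (X.to(torch.float32) - X_zero_point)
    y_fp = Y_scale * (Y.to(torch.float32) - Y_zero_point)
    # Add in float, then requantize to the output scale/zero point.
    q = torch.round((x_fp + y_fp) / out_scale) + out_zero_point
    # Output dtype assumed to match the input dtype, as in the tests below.
    info = torch.iinfo(X.dtype)
    return q.clamp(info.min, info.max).to(X.dtype)

With the values from the test case below (inputs of 5, all scales 0.8, all zero points 4): each input dequantizes to 0.8 * (5 - 4) = 0.8, the sum is 1.6, and round(1.6 / 0.8) + 4 = 6, the expected_value in the test.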

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 22 additions & 1 deletion
@@ -100,7 +100,7 @@ def test_dequantize_per_tensor(
         [
             # Only these types need to be tested as per ET_FORALL_JARVIS_QUANTIZED_TYPES in
             # on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h
-            ("int16", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8),
+            ("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8),
             ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8),
         ]
     )
@@ -122,6 +122,27 @@ def test_quantized_add(
         Y_tensor = torch.tensor([Y], dtype=dtype)
         expected_output = torch.tensor([expected_value], dtype=dtype)
 
+        quantized_add = (
+            torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor
+            if dtype == torch.int8
+            else torch.ops.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor
+        )
+        output = quantized_add(
+            X_tensor,
+            X_scale,
+            X_zero_point,
+            Y_tensor,
+            Y_scale,
+            Y_zero_point,
+            out_scale,
+            out_zero_point,
+        )
+
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Values don't match in {name}: got {output}, expected {expected_output}",
+        )
+
         output = torch.ops.cadence.quantized_add(
             X_tensor,
             X_scale,
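
Assuming the Cadence op library has been loaded so that torch.ops.cadence resolves (as the updated test exercises), a direct call to the new int8 variant with the same values as the test case would look like:

import torch

X = torch.tensor([5], dtype=torch.int8)
Y = torch.tensor([5], dtype=torch.int8)
# scale=0.8 and zero_point=4 for both inputs and the output, per the test case
out = torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor(
    X, 0.8, 4, Y, 0.8, 4, 0.8, 4
)
# out == tensor([6], dtype=torch.int8); a non-int8 input raises ValueError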
