Skip to content

Commit e42c881

Browse files
authored
Add backend-agnostic implementation for quantize_per_tensor
Differential Revision: D81187339 Pull Request resolved: #13769
1 parent 3e69a64 commit e42c881

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,17 @@ python_unittest(
604604
"//later:lib",
605605
],
606606
)
607+
608+
# Unit tests for the reference operator implementations in
# ref_implementations.py.
python_unittest(
    name = "test_ref_implementations",
    srcs = [
        "tests/test_ref_implementations.py",
    ],
    supports_static_listing = False,
    typing = True,
    deps = [
        ":typing_stubs",
        "//executorch/backends/cadence/aot:ref_implementations",
        "//caffe2:torch",
    ],
)

backends/cadence/aot/ref_implementations.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,42 @@
2020
}
2121

2222

23+
@impl(m, "quantize_per_tensor")
24+
def quantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Quantizes a floating-point tensor to an integral tensor.

    Computes ``round(input / scale + zero_point)``, saturates the result to
    ``[quant_min, quant_max]`` and casts it to ``dtype``.

    Args:
        - input (Tensor): input tensor
        - scale (float): Quantization scale. Derived from the ratio
            between the min/max of the floating-point tensor and the
            min/max of the quantized range.
        - zero_point (int): The point which represents 0 in the quantized
            range. For example, consider the floating point range [-1., 2.] and
            quantized integer range [-7, 7]. In this case, 0 is 1/3 of the way
            from -1. to 2. So, the point that represents 0 in the quantized
            range should be 1/3 of the way from -7 to 7. This ends up being -2
            in the integer space.
        - quant_min (int): The smallest value in the quantized domain. Results
            below it are saturated to this value.
        - quant_max (int): The largest value in the quantized domain. Results
            above it are saturated to this value.
        - dtype (torch.dtype): The type of the output tensor

    Returns:
        A tensor of type ``dtype`` holding the quantized values.

    Raises:
        ValueError: If ``dtype`` is not one of torch.int8, torch.int16 or
            torch.int32.
    """
    supported_quant_types = [torch.int8, torch.int16, torch.int32]
    if dtype not in supported_quant_types:
        raise ValueError(
            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
        )

    rounded = torch.round(input / scale + zero_point)
    # Saturate before the cast: converting an out-of-range float to an
    # integer dtype does not clamp, so without this step large inputs would
    # overflow instead of pinning to the ends of the quantized range.
    saturated = torch.clamp(rounded, min=quant_min, max=quant_max)
    return saturated.to(dtype)
57+
58+
2359
@impl(m, "requantize")
2460
def requantize(
2561
input: torch.Tensor,
backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
13+
from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor
14+
from executorch.backends.cadence.aot.typing_stubs import expand
15+
16+
17+
class TestRefImplementations(unittest.TestCase):
    """Unit tests for the operators defined in ref_implementations.py."""

    @expand(
        [
            ("basic_int8", 0.42, -1.0, 2.0, -7, 7, torch.int8, 0),
            ("basic_int16", 0.42, -1.0, 5.0, -6, 7, torch.int16, -3),
        ]
    )
    def test_quantize_per_tensor(
        self,
        name: str,
        input_value: float,
        f_min: float,
        f_max: float,
        q_min: int,
        q_max: int,
        target_dtype: torch.dtype,
        expected_value: int,
    ) -> None:
        """
        Quantize a single value and compare it against the expected integer.

        The scale and zero point are derived from the float range
        [f_min, f_max] and the quantized range [q_min, q_max] before calling
        quantize_per_tensor.
        """
        # Size of one quantization step in floating-point units.
        float_span = f_max - f_min
        quant_span = q_max - q_min
        step = float_span / quant_span
        # Quantized value that float 0.0 maps to.
        zero_offset = q_min + round(-f_min / step)

        output = quantize_per_tensor(
            torch.tensor([input_value]),
            step,
            zero_offset,
            q_min,
            q_max,
            target_dtype,
        )
        expected_output = torch.tensor([expected_value], dtype=target_dtype)

        self.assertEqual(
            output.dtype, expected_output.dtype, f"Dtype mismatch in {name}"
        )
        values_match = torch.equal(output, expected_output)
        self.assertTrue(
            values_match,
            f"Values don't match in {name}: got {output}, expected {expected_output}",
        )

0 commit comments

Comments
 (0)