
Commit 7cacfcd

Add backend-agnostic implementation for dequantize_per_tensor
Differential Revision: D81266532
Pull Request resolved: #13777
1 parent: 389918b

File tree: 2 files changed, +104 −1 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 56 additions, 0 deletions
@@ -56,6 +56,62 @@ def quantize_per_tensor(
     return torch.round(input / scale + zero_point).to(dtype)


+@impl(m, "dequantize_per_tensor")
+def dequantize_per_tensor(
+    input_tensor: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Dequantizes an integral tensor to a floating-point tensor.
+
+    Args:
+        - input_tensor (Tensor): input tensor
+        - scale (float): Quantization scale. Derived from the ratio
+            between the min/max of the floating-point tensor and the
+            min/max of the quantized range.
+        - zero_point (int): The point which represents 0 in the quantized
+            range. For example, consider the floating-point range [-1., 2.] and the
+            quantized integer range [-7, 7]. In this case, 0 is 1/3 of the way from
+            -1. to 2. So, the point that represents 0 in the quantized range should
+            be 1/3 of the way from -7 to 7. This ends up being -2 in the integer space.
+        - quant_min (int): The smallest value in the quantized domain. Unused since scale
+            is already provided.
+        - quant_max (int): The largest value in the quantized domain. Unused since scale
+            is already provided.
+        - dtype (torch.dtype): The type of the output tensor. Must be a floating-point type.
+    """
+    supported_quant_types = [
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.uint8,
+        torch.uint16,
+    ]
+    if input_tensor.dtype not in supported_quant_types:
+        raise ValueError(f"Input dtype must be one of {supported_quant_types}")
+    supported_dequant_types = [
+        torch.float,
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ]
+    if dtype not in supported_dequant_types:
+        raise ValueError(
+            f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}"
+        )
+
+    # Needed to prevent underflow in cases where the zero_point is larger than
+    # the quantized value.
+    if not input_tensor.dtype.is_signed:
+        input_tensor = input_tensor.to(torch.int32)
+
+    return (input_tensor - zero_point).to(dtype) * scale
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
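For reference, the rule the new kernel implements is real_value = (quantized_value - zero_point) * scale. Below is a minimal standalone sketch (not part of the commit), using the parameters from the docstring example of the float range [-1., 2.] quantized to the integer range [-7, 7]:

import torch

# Dequantize a few int8 values with the docstring's example parameters.
scale = (2.0 - (-1.0)) / (7 - (-7))  # 3/14 ~= 0.2143
zero_point = -2                      # the integer that represents 0.0, per the docstring

q = torch.tensor([-7, -2, 0, 7], dtype=torch.int8)
# Widen before subtracting, mirroring the kernel's underflow guard.
real = (q.to(torch.int32) - zero_point).to(torch.float32) * scale
print(real)  # tensor([-1.0714,  0.0000,  0.4286,  1.9286])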

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 48 additions, 1 deletion
@@ -10,7 +10,10 @@

 import torch

-from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor
+from executorch.backends.cadence.aot.ref_implementations import (
+    dequantize_per_tensor,
+    quantize_per_tensor,
+)
 from executorch.backends.cadence.aot.typing_stubs import expand


@@ -48,3 +51,47 @@ def test_quantize_per_tensor(
             torch.equal(output, expected_output),
             f"Values don't match in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Signed quantization ranges
+            ("signed_range0_int8", 0, -1.0, 2.0, -7, 7, torch.int8, 0.428),
+            ("signed_range0_int16", 0, -1.0, 2.0, -7, 7, torch.int16, 0.428),
+            ("signed_range0_int32", 0, -1.0, 2.0, -7, 7, torch.int32, 0.428),
+            ("signed_range1_int8", -3, -1.0, 5.0, -6, 7, torch.int8, 0.461),
+            ("signed_range1_int16", -3, -1.0, 5.0, -6, 7, torch.int16, 0.461),
+            ("signed_range1_int32", -3, -1.0, 5.0, -6, 7, torch.int32, 0.461),
+            # Unsigned quantization ranges
+            ("unsigned_range0_uint8", 3, -1.0, 2.0, 0, 7, torch.uint8, 0.428),
+            ("unsigned_range0_uint16", 3, -1.0, 2.0, 0, 7, torch.uint16, 0.428),
+            ("unsigned_range1_uint8", 4, -1.0, 5.0, 3, 7, torch.uint8, 0.0),
+            ("unsigned_range1_uint16", 4, -1.0, 5.0, 3, 7, torch.uint16, 0.0),
+        ]
+    )
+    def test_dequantize_per_tensor(
+        self,
+        name: str,
+        input_value: int,
+        f_min: float,
+        f_max: float,
+        q_min: int,
+        q_max: int,
+        input_dtype: torch.dtype,
+        expected_value: float,
+    ) -> None:
+        input_tensor = torch.tensor([input_value], dtype=input_dtype)
+        scale = (f_max - f_min) / (q_max - q_min)
+        zero_point = round(-f_min / scale) + q_min
+        expected_output = torch.tensor([expected_value], dtype=torch.float32)
+
+        output = dequantize_per_tensor(
+            input_tensor, scale, zero_point, q_min, q_max, torch.float32
+        )
+
+        self.assertEqual(
+            output.dtype, expected_output.dtype, f"Dtype mismatch in {name}"
+        )
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=0.001, atol=0.001),
+            f"Values don't match in {name}: got {output}, expected {expected_output}",
+        )
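To see how the test's parametrization produces its expected values, here is the arithmetic for the signed_range0_int8 case worked out by hand (a standalone check, not part of the commit):

# signed_range0_int8: input_value=0, f_min=-1.0, f_max=2.0, q_min=-7, q_max=7
scale = (2.0 - (-1.0)) / (7 - (-7))      # 3/14 ~= 0.2142857
zero_point = round(-(-1.0) / scale) - 7  # round(4.667) - 7 = -2
dequantized = (0 - zero_point) * scale   # (0 + 2) * 3/14 ~= 0.4286
# 0.4286 matches the expected 0.428 within the test's rtol/atol of 0.001.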
