@@ -10,7 +10,10 @@

 import torch

-from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor
+from executorch.backends.cadence.aot.ref_implementations import (
+    dequantize_per_tensor,
+    quantize_per_tensor,
+)
 from executorch.backends.cadence.aot.typing_stubs import expand


@@ -48,3 +51,47 @@ def test_quantize_per_tensor(
             torch.equal(output, expected_output),
             f"Values don't match in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Signed quantization ranges
+            ("signed_range0_int8", 0, -1.0, 2.0, -7, 7, torch.int8, 0.428),
+            ("signed_range0_int16", 0, -1.0, 2.0, -7, 7, torch.int16, 0.428),
+            ("signed_range0_int32", 0, -1.0, 2.0, -7, 7, torch.int32, 0.428),
+            ("signed_range1_int8", -3, -1.0, 5.0, -6, 7, torch.int8, 0.461),
+            ("signed_range1_int16", -3, -1.0, 5.0, -6, 7, torch.int16, 0.461),
+            ("signed_range1_int32", -3, -1.0, 5.0, -6, 7, torch.int32, 0.461),
+            # Unsigned quantization ranges
+            ("unsigned_range0_uint8", 3, -1.0, 2.0, 0, 7, torch.uint8, 0.428),
+            ("unsigned_range0_uint16", 3, -1.0, 2.0, 0, 7, torch.uint16, 0.428),
+            ("unsigned_range1_uint8", 4, -1.0, 5.0, 3, 7, torch.uint8, 0.0),
+            ("unsigned_range1_uint16", 4, -1.0, 5.0, 3, 7, torch.uint16, 0.0),
+        ]
+    )
+    def test_dequantize_per_tensor(
+        self,
+        name: str,
+        input_value: int,
+        f_min: float,
+        f_max: float,
+        q_min: int,
+        q_max: int,
+        input_dtype: torch.dtype,
+        expected_value: float,
+    ) -> None:
+        input_tensor = torch.tensor([input_value], dtype=input_dtype)
+        scale = (f_max - f_min) / (q_max - q_min)
+        zero_point = round(-f_min / scale) + q_min
+        expected_output = torch.tensor([expected_value], dtype=torch.float32)
+
+        output = dequantize_per_tensor(
+            input_tensor, scale, zero_point, q_min, q_max, torch.float32
+        )
+
+        self.assertEqual(
+            output.dtype, expected_output.dtype, f"Dtype mismatch in {name}"
+        )
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=0.001, atol=0.001),
+            f"Values don't match in {name}: got {output}, expected {expected_output}",
+        )
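Note on the parameter table: assuming the reference dequantize_per_tensor applies the usual affine mapping real = (quant - zero_point) * scale, the expected values follow directly from the scale and zero_point derived in the test body. A minimal plain-PyTorch sketch (illustrative only, not the Cadence reference implementation) reproducing the first row's expected 0.428:

import torch

# "signed_range0_int8" row: input_value=0, f_min=-1.0, f_max=2.0, q_min=-7, q_max=7
f_min, f_max, q_min, q_max = -1.0, 2.0, -7, 7
scale = (f_max - f_min) / (q_max - q_min)      # 3.0 / 14 ≈ 0.2143
zero_point = round(-f_min / scale) + q_min     # round(4.667) - 7 = -2
quant = torch.tensor([0], dtype=torch.int8)
real = (quant.to(torch.float32) - zero_point) * scale
print(real)  # tensor([0.4286]) -- within the test's 1e-3 tolerance of 0.428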