@@ -5,15 +5,17 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-strict
-
+import typing
 import unittest

+import numpy as np
 import torch

 from executorch.backends.cadence.aot.ref_implementations import (
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
+    quantized_linear,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand

@@ -138,3 +140,102 @@ def test_quantized_add(
             torch.equal(output, expected_output),
             f"Values don't match in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: 1x2 input, 1x2 weight (1 output feature)
+            (
+                torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
+                torch.Size([1, 2]),  # weight_shape: 1 output feature, 2 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int8),  # out_shift
+                0,  # out_zero_point
+                torch.tensor([[-2]], dtype=torch.int8),  # expected_output
+            ),
+            # Test case 2: 1x3 input, 2x3 weight (2 output features)
+            (
+                torch.Size([1, 3]),  # src_shape: 1 sample, 3 input features
+                torch.Size([2, 3]),  # weight_shape: 2 output features, 3 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0, 0], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int8),  # out_shift
+                0,  # out_zero_point
+                torch.tensor([[-10, -30]], dtype=torch.int8),  # expected_output
+            ),
+            # Test case 3: Batch case with different dimensions
+            (
+                torch.Size([1, 2, 2]),  # src_shape: batch=1, seq=2, features=2
+                torch.Size([3, 2]),  # weight_shape: 3 output features, 2 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int8),  # out_shift
+                0,  # out_zero_point
+                torch.tensor(
+                    [[[-2, -8, -14], [-6, -28, -50]]], dtype=torch.int8
+                ),  # expected_output
+            ),
+            # Test case 4: Non-zero zero points
+            (
+                torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
+                torch.Size([2, 2]),  # weight_shape: 2 output features, 2 input features
+                2,  # in_zero_point
+                torch.tensor([1, 1], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [268435456], dtype=torch.int32
+                ),  # out_multiplier (0.125 * 2^31)
+                torch.tensor([0], dtype=torch.int8),  # out_shift
+                1,  # out_zero_point
+                torch.tensor([[-15, 25]], dtype=torch.int8),  # expected_output
+            ),
+        ]
+    )
+    def test_quantized_linear(
+        self,
+        src_shape: torch.Size,
+        weight_shape: torch.Size,
+        in_zero_point: int,
+        weight_zero_point: torch.Tensor,
+        out_multiplier: torch.Tensor,
+        out_shift: torch.Tensor,
+        out_zero_point: int,
+        expected_output: torch.Tensor,
+    ) -> None:
+        src = (
+            torch.arange(np.prod(src_shape))
+            .reshape(src_shape)
+            .to(expected_output.dtype)
+        )
+        weight = (
+            torch.arange(np.prod(weight_shape))
+            .reshape(weight_shape)
+            .to(expected_output.dtype)
+        )
+        bias = torch.arange(weight_shape[0]).to(expected_output.dtype)
+        output = quantized_linear(
+            src,
+            weight,
+            bias,
+            in_zero_point,
+            weight_zero_point,
+            out_multiplier,
+            out_shift,
+            out_zero_point,
+            typing.cast(torch.Tensor, None),  # cast None to Tensor only to satisfy the type checker
+        )
+
+        self.assertEqual(output.dtype, expected_output.dtype, "Dtype mismatch")
+
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Values don't match: got {output}, expected {expected_output}",
+        )
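
For readers double-checking the expected values, here is a minimal sketch (not part of the diff) of how the test constructs its inputs and what the raw integer accumulator looks like for test case 2. It assumes the common affine-quantization convention `acc = (src - in_zp) @ (weight - w_zp).T + bias`; the final `out_multiplier`/`out_shift` requantization stage is specific to `ref_implementations.quantized_linear` and is deliberately not reproduced here.

```python
# Sketch only: rebuild the inputs of test case 2 and compute the integer
# accumulator that feeds requantization. The affine convention below is an
# assumption; the out_multiplier / out_shift stage is left to the reference op.
import torch

src_shape, weight_shape = torch.Size([1, 3]), torch.Size([2, 3])
src = torch.arange(src_shape.numel()).reshape(src_shape)           # [[0, 1, 2]]
weight = torch.arange(weight_shape.numel()).reshape(weight_shape)  # [[0, 1, 2], [3, 4, 5]]
bias = torch.arange(weight_shape[0])                               # [0, 1]

in_zero_point = 0
weight_zero_point = torch.tensor([0, 0, 0])  # one entry per input feature, as in the test

# Widen to int32 before the matmul so int8-range values cannot overflow.
acc = (src.int() - in_zero_point) @ (weight.int() - weight_zero_point.int()).T
acc = acc + bias.int()
print(acc)  # tensor([[ 5, 15]], dtype=torch.int32)
```

Note that the `out_multiplier` values are Q31 fixed-point scales: `1073741824 == 0.5 * 2**31` and `268435456 == 0.125 * 2**31`, matching the inline comments.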