Commit ea5cf49

Add backend-agnostic implementation for quantized_linear
Differential Revision: D81363750
Pull Request resolved: #13897
1 parent 32e82bc commit ea5cf49

2 files changed: +152 −1 lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 50 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+from typing import Optional
+
 import torch
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library
@@ -177,6 +179,54 @@ def quantized_add(
     )
 
 
+@impl(m, "quantized_linear")
+def quantized_linear(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Quantized linear (transposed matmul) operation.
+
+    Args:
+        - src (Tensor): The activations tensor
+        - weight (Tensor): The weight tensor
+        - bias (Tensor): The bias tensor
+        - in_zero_point (int): The quantized mapping of zero for the input
+        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
+        - out_multiplier (Tensor): The multiplier used to scale the output
+        - out_shift (Tensor): The shift used to scale the output
+        - out_zero_point (int): The quantized mapping of zero for the output
+        - offset (Tensor): Unused
+    """
+    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+
+    N, K = weight.shape
+
+    leading_dims = src.shape[:-1]
+    src = src.view(-1, K)
+
+    dtype = src.dtype
+    supported_dtypes = [torch.int8, torch.uint8, torch.int32]
+    if dtype not in supported_dtypes:
+        raise ValueError(
+            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
+        )
+
+    out = torch.nn.functional.linear(
+        src - in_zero_point, weight - weight_zero_point, bias
+    )
+    return quantize_per_tensor(
+        out, out_scale, out_zero_point, -128, 127, dtype
+    ).reshape(*leading_dims, N)
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
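
Reading the new op: out_multiplier holds a Q1.31 fixed-point value (scale × 2^31) and out_shift applies a power-of-two correction, so out_scale decodes to multiplier / 2^31 · 2^shift, negated here to match the sign convention of quantize_per_tensor. The following is a minimal standalone sketch of that arithmetic using the numbers from test case 1 of the new tests; quantize_ref is a hypothetical simplified stand-in for the repo's quantize_per_tensor (assumed divide-round-saturate behavior inferred from the expected test outputs), not the actual implementation.

import torch

def quantize_ref(x: torch.Tensor, scale: torch.Tensor, zero_point: int,
                 qmin: int, qmax: int, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stand-in for quantize_per_tensor: divide by the scale,
    # add the zero point, then saturate to the quantized range.
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(dtype)

# Numbers from test case 1 below (all zero points are 0).
src = torch.tensor([[0, 1]])                  # activations: 1 sample x 2 features
weight = torch.tensor([[0, 1]])               # 1 output feature x 2 input features
bias = torch.tensor([0])
out_multiplier = torch.tensor([1073741824])   # 0.5 * 2^31 in Q1.31
out_shift = torch.tensor([0])

# Same decoding as the op: Q1.31 multiplier plus a power-of-two shift.
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])  # -> -0.5

acc = src @ weight.T + bias                   # integer accumulator: [[1]]
out = quantize_ref(acc, out_scale, 0, -128, 127, torch.int8)
print(out)                                    # tensor([[-2]], dtype=torch.int8)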

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 102 additions & 1 deletion
@@ -5,15 +5,17 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-
+import typing
 import unittest
 
+import numpy as np
 import torch
 
 from executorch.backends.cadence.aot.ref_implementations import (
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
+    quantized_linear,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand
 
@@ -138,3 +140,102 @@ def test_quantized_add(
             torch.equal(output, expected_output),
             f"Values don't match in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: 1x2 input, 1x2 weight (1 output feature)
+            (
+                torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
+                torch.Size([1, 2]),  # weight_shape: 1 output feature, 2 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int8),  # out_shift
+                0,  # out_zero_point
+                torch.tensor([[-2]], dtype=torch.int8),  # expected_output
+            ),
+            # Test case 2: 1x3 input, 2x3 weight (2 output features)
+            (
+                torch.Size([1, 3]),  # src_shape: 1 sample, 3 input features
+                torch.Size([2, 3]),  # weight_shape: 2 output features, 3 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0, 0], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int8),  # out_shift
+                0,  # out_zero_point
+                torch.tensor([[-10, -30]], dtype=torch.int8),  # expected_output
+            ),
+            # Test case 3: Batch case with different dimensions
+            (
+                torch.Size([1, 2, 2]),  # src_shape: batch=1, seq=2, features=2
+                torch.Size([3, 2]),  # weight_shape: 3 output features, 2 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int8),  # out_shift
+                0,  # out_zero_point
+                torch.tensor(
+                    [[[-2, -8, -14], [-6, -28, -50]]], dtype=torch.int8
+                ),  # expected_output
+            ),
+            # Test case 4: Non-zero zero points
+            (
+                torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
+                torch.Size([2, 2]),  # weight_shape: 2 output features, 2 input features
+                2,  # in_zero_point
+                torch.tensor([1, 1], dtype=torch.int8),  # weight_zero_point
+                torch.tensor(
+                    [268435456], dtype=torch.int32
+                ),  # out_multiplier (0.125 * 2^31)
+                torch.tensor([0]),  # out_shift
+                1,  # out_zero_point
+                torch.tensor([[-15, 25]], dtype=torch.int8),  # expected_output
+            ),
+        ]
+    )
+    def test_quantized_linear(
+        self,
+        src_shape: torch.Size,
+        weight_shape: torch.Size,
+        in_zero_point: int,
+        weight_zero_point: torch.Tensor,
+        out_multiplier: torch.Tensor,
+        out_shift: torch.Tensor,
+        out_zero_point: int,
+        expected_output: torch.Tensor,
+    ) -> None:
+        src = (
+            torch.arange(np.product(src_shape))
+            .reshape(src_shape)
+            .to(expected_output.dtype)
+        )
+        weight = (
+            torch.arange(np.product(weight_shape))
+            .reshape(weight_shape)
+            .to(expected_output.dtype)
+        )
+        bias = torch.arange(weight_shape[0]).to(expected_output.dtype)
+        output = quantized_linear(
+            src,
+            weight,
+            bias,
+            in_zero_point,
+            weight_zero_point,
+            out_multiplier,
+            out_shift,
+            out_zero_point,
+            typing.cast(torch.Tensor, None),
+        )
+
+        self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
+
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Values don't match: got {output}, expected {expected_output}",
+        )
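
The last parameterization exercises non-zero zero points, and its expected values can be reproduced by hand. Below is a quick sketch of that check, assuming the same divide-round-add-zero-point requantization as in the reference op; plain int64 tensors are used here purely for the arithmetic, while the test itself runs end to end in int8.

import torch

# Inputs as the test constructs them: arange values reshaped to the test shapes.
src = torch.tensor([[0, 1]])              # src_shape [1, 2]
weight = torch.tensor([[0, 1], [2, 3]])   # weight_shape [2, 2]
bias = torch.tensor([0, 1])

# Subtract the zero points (in_zero_point=2, weight_zero_point=1), then matmul.
acc = (src - 2) @ (weight - 1).T + bias   # -> tensor([[ 2, -3]])

out_scale = -268435456 / (1 << 31)        # -0.125; out_shift is 0
out = torch.round(acc / out_scale) + 1    # out_zero_point = 1
print(out)                                # tensor([[-15., 25.]])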
