Skip to content

Commit de30390

Browse files
authored
Support for all quantized linear ops
Differential Revision: D81940978 Pull Request resolved: #14078
1 parent f294074 commit de30390

File tree

2 files changed

+251
-61
lines changed

2 files changed

+251
-61
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 98 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99

10-
from typing import Callable, Optional
10+
from typing import Callable
1111

1212
import torch
1313

@@ -193,17 +193,15 @@ def quantized_add(
193193
)
194194

195195

196-
@impl(m, "quantized_linear")
197-
def quantized_linear(
196+
def quantized_linear_common(
198197
src: torch.Tensor,
199198
weight: torch.Tensor,
200199
bias: torch.Tensor,
201200
in_zero_point: int,
202-
weight_zero_point: torch.Tensor,
203-
out_multiplier: torch.Tensor,
204-
out_shift: torch.Tensor,
201+
weight_zero_point: torch.Tensor | int,
202+
out_multiplier: torch.Tensor | int,
203+
out_shift: int,
205204
out_zero_point: int,
206-
offset: Optional[torch.Tensor],
207205
) -> torch.Tensor:
208206
"""
209207
Quantized linear (transposed matmul) operation.
@@ -219,7 +217,7 @@ def quantized_linear(
219217
- out_zero_point (int): The quantized mapping of zero for the output
220218
- offset (Tensor): Unused; not a parameter of this helper (the registered op variants accept it and ignore it)
221219
"""
222-
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
220+
out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
223221
out_scale_inv = 1 / out_scale
224222

225223
N, K = weight.shape
@@ -235,7 +233,9 @@ def quantized_linear(
235233
)
236234

237235
out = torch.nn.functional.linear(
238-
src - in_zero_point, weight - weight_zero_point, bias
236+
(src - in_zero_point).float(),
237+
(weight - weight_zero_point).float(),
238+
bias.float(),
239239
)
240240
return quantize_per_tensor(
241241
out,
@@ -247,6 +247,95 @@ def quantized_linear(
247247
).reshape(*leading_dims, N)
248248

249249

250+
def quantized_linear_variant(
251+
per_tensor: bool,
252+
src_dtype: torch.dtype | None = None,
253+
weight_dtype: torch.dtype | None = None,
254+
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
255+
256+
def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
257+
def variant(
258+
src: torch.Tensor,
259+
weight: torch.Tensor,
260+
bias: torch.Tensor,
261+
in_zero_point: int,
262+
weight_zero_point: torch.Tensor | int,
263+
out_multiplier: torch.Tensor | int,
264+
out_shift: torch.Tensor | int,
265+
out_zero_point: int,
266+
offset: torch.Tensor | None = None,
267+
) -> torch.Tensor:
268+
if src_dtype and src.dtype != src_dtype:
269+
raise ValueError(
270+
f"src dtype must be {src_dtype}. Got {src.dtype} instead"
271+
)
272+
if weight_dtype and weight.dtype != weight_dtype:
273+
raise ValueError(
274+
f"weight dtype must be {weight_dtype}. Got {weight.dtype} instead"
275+
)
276+
if bias.dtype != torch.int32:
277+
raise ValueError(
278+
f"bias dtype must be torch.int32. Got {bias.dtype} instead"
279+
)
280+
281+
if per_tensor:
282+
assert isinstance(weight_zero_point, int)
283+
assert isinstance(out_multiplier, int)
284+
assert isinstance(out_shift, int)
285+
return quantized_linear_common(
286+
src,
287+
weight,
288+
bias,
289+
in_zero_point,
290+
weight_zero_point,
291+
out_multiplier,
292+
out_shift,
293+
out_zero_point,
294+
)
295+
else:
296+
assert isinstance(out_shift, torch.Tensor)
297+
if out_shift.numel() != 1:
298+
raise ValueError("out_shift must be a scalar")
299+
300+
if out_shift.dtype != torch.int64:
301+
raise ValueError("out_shift must be an int64")
302+
303+
return quantized_linear_common(
304+
src,
305+
weight,
306+
bias,
307+
in_zero_point,
308+
weight_zero_point,
309+
out_multiplier,
310+
int(out_shift.item()),
311+
out_zero_point,
312+
)
313+
314+
return variant
315+
316+
return decorator
317+
318+
319+
@impl(m, "quantized_linear")
320+
@quantized_linear_variant(False)
321+
def quantized_linear() -> torch.Tensor: ...
322+
323+
324+
@impl(m, "quantized_linear.per_tensor")
325+
@quantized_linear_variant(True)
326+
def quantized_linear_per_tensor() -> torch.Tensor: ...
327+
328+
329+
@impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor")
330+
@quantized_linear_variant(True, torch.int8, torch.int8)
331+
def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...
332+
333+
334+
@impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor")
335+
@quantized_linear_variant(True, torch.uint8, torch.uint8)
336+
def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
337+
338+
250339
@impl(m, "quantized_layer_norm.per_tensor")
251340
def quantized_layer_norm_per_tensor(
252341
input_tensor: torch.Tensor,

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 153 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -141,59 +141,139 @@ def test_quantized_add(
141141
@expand(
142142
[
143143
# Test case 1: 1x2 input, 1x2 weight (1 output feature)
144-
(
145-
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
146-
torch.Size([1, 2]), # weight_shape: 1 output feature, 2 input features
147-
0, # in_zero_point
148-
torch.tensor([0, 0], dtype=torch.int8), # weight_zero_point
149-
torch.tensor(
150-
[1073741824], dtype=torch.int32
151-
), # out_multiplier (0.5 * 2^31)
152-
torch.tensor([0], dtype=torch.int8), # out_shift
153-
0, # out_zero_point
154-
torch.tensor([[-2]], dtype=torch.int8), # expected_output
155-
),
144+
*[
145+
(
146+
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
147+
torch.Size(
148+
[1, 2]
149+
), # weight_shape: 1 output feature, 2 input features
150+
0, # in_zero_point
151+
torch.tensor([0, 0], dtype=dtype), # weight_zero_point
152+
torch.tensor(
153+
[1073741824], dtype=torch.int32
154+
), # out_multiplier (0.5 * 2^31)
155+
torch.tensor([0], dtype=torch.int64), # out_shift
156+
0, # out_zero_point
157+
torch.tensor([[-2]], dtype=dtype), # expected_output
158+
per_tensor,
159+
)
160+
for (per_tensor, dtype) in (
161+
(False, torch.int8),
162+
(True, torch.int8),
163+
(True, torch.uint8),
164+
)
165+
],
156166
# Test case 2: 1x3 input, 2x3 weight (2 output features)
157-
(
158-
torch.Size([1, 3]), # src_shape: 1 sample, 3 input features
159-
torch.Size([2, 3]), # weight_shape: 2 output features, 3 input features
160-
0, # in_zero_point
161-
torch.tensor([0, 0, 0], dtype=torch.int8), # weight_zero_point
162-
torch.tensor(
163-
[1073741824], dtype=torch.int32
164-
), # out_multiplier (0.5 * 2^31)
165-
torch.tensor([0], dtype=torch.int8), # out_shift
166-
0, # out_zero_point
167-
torch.tensor([[-10, -30]], dtype=torch.int8), # expected_output
168-
),
167+
*[
168+
(
169+
torch.Size([1, 3]), # src_shape: 1 sample, 3 input features
170+
torch.Size(
171+
[2, 3]
172+
), # weight_shape: 2 output features, 3 input features
173+
0, # in_zero_point
174+
torch.tensor([0, 0, 0], dtype=dtype), # weight_zero_point
175+
torch.tensor(
176+
[1073741824], dtype=torch.int32
177+
), # out_multiplier (0.5 * 2^31)
178+
torch.tensor([0], dtype=torch.int64), # out_shift
179+
0, # out_zero_point
180+
torch.tensor([[-10, -30]], dtype=dtype), # expected_output
181+
per_tensor,
182+
)
183+
for (per_tensor, dtype) in (
184+
(False, torch.int8),
185+
(True, torch.int8),
186+
(True, torch.uint8),
187+
)
188+
],
169189
# Test case 3: Batch case with different dimensions
170-
(
171-
torch.Size([1, 2, 2]), # src_shape: batch=1, seq=2, features=2
172-
torch.Size([3, 2]), # weight_shape: 3 output features, 2 input features
173-
0, # in_zero_point
174-
torch.tensor([0, 0], dtype=torch.int8), # weight_zero_point
175-
torch.tensor(
176-
[1073741824], dtype=torch.int32
177-
), # out_multiplier (0.5 * 2^31)
178-
torch.tensor([0], dtype=torch.int8), # out_shift
179-
0, # out_zero_point
180-
torch.tensor(
181-
[[[-2, -8, -14], [-6, -28, -50]]], dtype=torch.int8
182-
), # expected_output
183-
),
190+
*[
191+
(
192+
torch.Size([1, 2, 2]), # src_shape: batch=1, seq=2, features=2
193+
torch.Size(
194+
[3, 2]
195+
), # weight_shape: 3 output features, 2 input features
196+
0, # in_zero_point
197+
torch.tensor([0, 0], dtype=dtype), # weight_zero_point
198+
torch.tensor(
199+
[1073741824], dtype=torch.int32
200+
), # out_multiplier (0.5 * 2^31)
201+
torch.tensor([0], dtype=torch.int64), # out_shift
202+
0, # out_zero_point
203+
torch.tensor(
204+
[[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
205+
), # expected_output
206+
per_tensor,
207+
)
208+
for (per_tensor, dtype) in (
209+
(False, torch.int8),
210+
(True, torch.int8),
211+
(True, torch.uint8),
212+
)
213+
],
184214
# Test case 4: Non-zero zero points
185-
(
186-
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
187-
torch.Size([2, 2]), # weight_shape: 2 output feature, 1 input feature
188-
2, # in_zero_point
189-
torch.tensor([1, 1], dtype=torch.int8), # weight_zero_point
190-
torch.tensor(
191-
[268435456], dtype=torch.int32
192-
), # out_multiplier (1.0 * 2^31)
193-
torch.tensor([0]), # out_shift
194-
1, # out_zero_point
195-
torch.tensor([[-15, 25]], dtype=torch.int8), # expected_output
196-
),
215+
*[
216+
(
217+
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
218+
torch.Size(
219+
[2, 2]
220+
), # weight_shape: 2 output features, 2 input features
221+
2, # in_zero_point
222+
torch.tensor([1, 1], dtype=dtype), # weight_zero_point
223+
torch.tensor(
224+
[268435456], dtype=torch.int32
225+
), # out_multiplier (0.125 * 2^31)
226+
torch.tensor([0], dtype=torch.int64), # out_shift
227+
1, # out_zero_point
228+
torch.tensor([[-15, 25]], dtype=dtype), # expected_output
229+
per_tensor,
230+
)
231+
for (per_tensor, dtype) in (
232+
(False, torch.int8),
233+
(True, torch.int8),
234+
(True, torch.uint8),
235+
)
236+
],
237+
# Test case 5: Non-uniform weight zero points
238+
*[
239+
(
240+
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
241+
torch.Size(
242+
[2, 2]
243+
), # weight_shape: 2 output features, 2 input features
244+
2, # in_zero_point
245+
torch.tensor([1, 2], dtype=dtype), # weight_zero_point
246+
torch.tensor(
247+
[268435456], dtype=torch.int32
248+
), # out_multiplier (0.125 * 2^31)
249+
torch.tensor([0], dtype=torch.int64), # out_shift
250+
1, # out_zero_point
251+
torch.tensor([[-23, 17]], dtype=dtype), # expected_output
252+
False,
253+
)
254+
for dtype in (torch.int8, torch.uint8)
255+
],
256+
# Test case 6: Non-zero out_shift (shift=1)
257+
*[
258+
(
259+
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
260+
torch.Size(
261+
[2, 2]
262+
), # weight_shape: 2 output features, 2 input features
263+
2, # in_zero_point
264+
torch.tensor([1, 1], dtype=dtype), # weight_zero_point
265+
torch.tensor(
266+
[268435456], dtype=torch.int32
267+
), # out_multiplier (0.125 * 2^31)
268+
torch.tensor(
269+
[1], dtype=torch.int64
270+
), # out_shift (shift=1, doubles the scale)
271+
1, # out_zero_point
272+
torch.tensor([[-7, 13]], dtype=dtype), # expected_output
273+
per_tensor,
274+
)
275+
for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
276+
],
197277
]
198278
)
199279
def test_quantized_linear(
@@ -206,6 +286,7 @@ def test_quantized_linear(
206286
out_shift: torch.Tensor,
207287
out_zero_point: int,
208288
expected_output: torch.Tensor,
289+
per_tensor: bool,
209290
) -> None:
210291
src = (
211292
torch.arange(np.prod(src_shape))
@@ -217,8 +298,28 @@ def test_quantized_linear(
217298
.reshape(weight_shape)
218299
.to(expected_output.dtype)
219300
)
220-
bias = torch.arange(weight_shape[0]).to(expected_output.dtype)
221-
output = torch.ops.cadence.quantized_linear(
301+
bias = torch.arange(weight_shape[0]).to(torch.int32)
302+
if per_tensor:
303+
weight_zero_point = weight_zero_point[0]
304+
out_multiplier = out_multiplier[0]
305+
out_shift = out_shift[0]
306+
307+
if per_tensor:
308+
match expected_output.dtype:
309+
case torch.int8:
310+
linear_op = (
311+
torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor
312+
)
313+
case torch.uint8:
314+
linear_op = (
315+
torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor
316+
)
317+
case _:
318+
linear_op = torch.ops.cadence.quantized_linear.per_tensor
319+
else:
320+
linear_op = torch.ops.cadence.quantized_linear
321+
322+
output = linear_op(
222323
src,
223324
weight,
224325
bias,

0 commit comments

Comments
 (0)