Commit eec95d0

Support custom quantized_matmul + variants
Differential Revision: D81973532
Pull Request resolved: #14095
1 parent: e9903b8

2 files changed: 225 additions, 44 deletions

backends/cadence/aot/ref_implementations.py

Lines changed: 123 additions & 21 deletions
@@ -241,7 +241,7 @@ def quantized_linear_common(
     bias: torch.Tensor,
     in_zero_point: int,
     weight_zero_point: torch.Tensor | int,
-    out_multiplier: torch.Tensor | int,
+    out_multiplier: int,
     out_shift: int,
     out_zero_point: int,
 ) -> torch.Tensor:
@@ -329,34 +329,30 @@ def variant(
             assert isinstance(weight_zero_point, int)
             assert isinstance(out_multiplier, int)
             assert isinstance(out_shift, int)
-            return quantized_linear_common(
-                src,
-                weight,
-                bias,
-                in_zero_point,
-                weight_zero_point,
-                out_multiplier,
-                out_shift,
-                out_zero_point,
-            )
+            _out_shift = out_shift
+            _out_multiplier = out_multiplier
         else:
             assert isinstance(out_shift, torch.Tensor)
+            assert isinstance(out_multiplier, torch.Tensor)
             if out_shift.numel() != 1:
                 raise ValueError("out_shift must be a scalar")
 
             if out_shift.dtype != torch.int64:
                 raise ValueError("out_shift must be an int64")
 
-            return quantized_linear_common(
-                src,
-                weight,
-                bias,
-                in_zero_point,
-                weight_zero_point,
-                out_multiplier,
-                int(out_shift.item()),
-                out_zero_point,
-            )
+            _out_shift = int(out_shift.item())
+            _out_multiplier = int(out_multiplier[0].item())
+
+        return quantized_linear_common(
+            src,
+            weight,
+            bias,
+            in_zero_point,
+            weight_zero_point,
+            _out_multiplier,
+            _out_shift,
+            out_zero_point,
+        )
 
     return variant
 
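For context, the out_multiplier/out_shift pair that these variants normalize encodes the output rescale as a Q31 fixed-point multiplier plus a power-of-two shift. A minimal sketch of the assumed arithmetic, inferred from the test comments below (268435456 == 0.125 * 2^31; shift=1 doubles the scale); the exact rounding inside quantized_linear_common may differ:

def requantize(acc: int, out_multiplier: int, out_shift: int, out_zero_point: int) -> int:
    # Effective output scale: (out_multiplier / 2^31) * 2^out_shift.
    scale = (out_multiplier / (1 << 31)) * (1 << out_shift)
    return round(acc * scale) + out_zero_point

# round(100 * 0.125 * 2) + 1 == 26
assert requantize(100, 268435456, 1, 1) == 26
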
@@ -403,6 +399,112 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor:
 def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl(m, "quantized_matmul")
+def quantized_matmul(
+    X: torch.Tensor,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_zero_point: int,
+    bias: torch.Tensor | None,
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    transposed: bool = False,
+) -> torch.Tensor:
+    """
+    Quantized matmul operation.
+
+    Args:
+        - X (Tensor): The activations tensor
+        - X_zero_point (int): The quantized mapping of zero for the input
+        - Y (Tensor): The weight tensor
+        - Y_zero_point (int): The quantized mapping of zero for the weight
+        - bias (Tensor): The bias tensor
+        - out_multiplier (int): The multiplier used to scale the output
+        - out_shift (int): The shift used to scale the output
+        - out_zero_point (int): The quantized mapping of zero for the output
+        - transposed (bool): Whether to transpose the weight tensor
+    """
+    if bias is not None and not torch.all(bias == 0):
+        raise ValueError("bias must be None or all zeros since unused in out variant")
+
+    # Looks weird, but quantized linear assumes weights are pre-transposed,
+    # hence we transpose only if `transposed` is False.
+    if not transposed:
+        Y = Y.T
+
+    return quantized_linear_common(
+        X,
+        Y,
+        bias or torch.zeros(1, dtype=torch.int32),
+        X_zero_point,
+        Y_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+    )
+
+
+@impl(m, "quantized_matmul_asym8sxasym8s_asym8s")
+def quantized_matmul_asym8sxasym8s_asym8s(
+    X: torch.Tensor,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_zero_point: int,
+    bias: torch.Tensor | None,
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    transposed: bool = False,
+) -> torch.Tensor:
+    if X.dtype != torch.int8:
+        raise ValueError("X dtype must be torch.int8")
+    if Y.dtype != torch.int8:
+        raise ValueError("Y dtype must be torch.int8")
+
+    return quantized_matmul(
+        X,
+        X_zero_point,
+        Y,
+        Y_zero_point,
+        bias,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        transposed,
+    )
+
+
+@impl(m, "quantized_matmul_asym8uxasym8u_asym8u")
+def quantized_matmul_asym8uxasym8u_asym8u(
+    X: torch.Tensor,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_zero_point: int,
+    bias: torch.Tensor | None,
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    transposed: bool = False,
+) -> torch.Tensor:
+    if X.dtype != torch.uint8:
+        raise ValueError("X dtype must be torch.uint8")
+    if Y.dtype != torch.uint8:
+        raise ValueError("Y dtype must be torch.uint8")
+
+    return quantized_matmul(
+        X,
+        X_zero_point,
+        Y,
+        Y_zero_point,
+        bias,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        transposed,
+    )
+
+
 @impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
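
For illustration, a hypothetical call to the new signed variant (all values invented): with transposed=False the second operand is in plain (K, N) matmul layout, so the reference implementation transposes it before delegating to quantized_linear_common, which expects linear-style (N, K) weights.

import torch

# Hypothetical int8 operands: X is (1, 2); Y is (2, 2) in plain matmul layout.
X = torch.tensor([[2, 3]], dtype=torch.int8)
Y = torch.tensor([[1, 2], [3, 4]], dtype=torch.int8)

out = torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s(
    X,
    2,  # X_zero_point
    Y,
    1,  # Y_zero_point
    None,  # bias must be None (or all zeros)
    268435456,  # out_multiplier (0.125 * 2^31)
    1,  # out_shift
    0,  # out_zero_point
    False,  # transposed=False: Y is (K, N), transposed internally
)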

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 102 additions & 23 deletions
@@ -177,6 +177,8 @@ def test_quantized_add(
                     0,  # out_zero_point
                     torch.tensor([[-2]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -200,6 +202,8 @@ def test_quantized_add(
                     0,  # out_zero_point
                     torch.tensor([[-10, -30]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -225,6 +229,8 @@ def test_quantized_add(
                         [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
                     ),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -248,6 +254,8 @@ def test_quantized_add(
                     1,  # out_zero_point
                     torch.tensor([[-15, 25]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -271,6 +279,8 @@ def test_quantized_add(
                     1,  # out_zero_point
                     torch.tensor([[-23, 17]], dtype=dtype),  # expected_output
                     False,
+                    False,
+                    False,
                 )
                 for dtype in (torch.int8, torch.uint8)
             ],
@@ -292,9 +302,34 @@ def test_quantized_add(
                     1,  # out_zero_point
                     torch.tensor([[-7, 13]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
             ],
+            *[
+                (
+                    torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
+                    torch.Size(
+                        [2, 2]
+                    ),  # weight_shape: 2 output features, 2 input features
+                    2,  # in_zero_point
+                    torch.tensor([1, 1], dtype=dtype),  # weight_zero_point
+                    torch.tensor(
+                        [268435456], dtype=torch.int32
+                    ),  # out_multiplier (0.125 * 2^31)
+                    torch.tensor(
+                        [1], dtype=torch.int64
+                    ),  # out_shift (shift=1, doubles the scale)
+                    1,  # out_zero_point
+                    torch.tensor([[-7, 17]], dtype=dtype),  # expected_output
+                    per_tensor,
+                    matmul,
+                    transposed_matmul,
+                )
+                for (matmul, transposed_matmul) in ((True, False), (True, True))
+                for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8))
+            ],
         ]
     )
     def test_quantized_linear(
@@ -308,7 +343,12 @@ def test_quantized_linear(
         out_zero_point: int,
         expected_output: torch.Tensor,
         per_tensor: bool,
+        matmul: bool,
+        transposed_matmul: bool,
     ) -> None:
+        if not per_tensor and matmul:
+            self.skipTest("Only per_tensor supported for matmul")
+
         src = (
             torch.arange(np.prod(src_shape))
             .reshape(src_shape)
@@ -319,7 +359,9 @@
             .reshape(weight_shape)
             .to(expected_output.dtype)
         )
-        bias = torch.arange(weight_shape[0]).to(torch.int32)
+        if matmul and not transposed_matmul:
+            weight = weight.T
+
         if per_tensor:
             weight_zero_point = weight_zero_point[0]
             out_multiplier = out_multiplier[0]
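
The pre-transpose above mirrors the transpose inside quantized_matmul: for the untransposed case the test converts the linear-layout weight into plain matmul layout, and the reference implementation converts it back. A small sketch of that round trip (weight values hypothetical):

import torch

# Hypothetical linear-layout weight: (out_features, in_features).
W = torch.tensor([[0, 1], [2, 3]], dtype=torch.int8)

Y = W.T  # what the test passes when matmul and not transposed_matmul
assert torch.equal(Y.T, W)  # quantized_matmul's internal Y.T recovers W
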
@@ -328,38 +370,75 @@
         if per_tensor:
             match expected_output.dtype:
                 case torch.int8:
-                    linear_ops = (
-                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
-                        torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
-                    )
+                    if matmul:
+                        linear_ops = (
+                            # Doesn't have per tensor name, but it is per tensor
+                            torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s,
+                        )
+                    else:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
+                            torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
+                        )
                 case torch.uint8:
-                    linear_ops = (
-                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
-                        torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
-                    )
+                    if matmul:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_matmul_asym8uxasym8u_asym8u,
+                        )
+                    else:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
+                            torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
+                        )
                 case _:
-                    linear_ops = (
-                        torch.ops.cadence.quantized_linear.per_tensor,
-                        torch.ops.cadence.quantized_fully_connected.per_tensor,
-                    )
+                    if matmul:
+                        linear_ops = (torch.ops.cadence.quantized_matmul,)
+                    else:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_linear.per_tensor,
+                            torch.ops.cadence.quantized_fully_connected.per_tensor,
+                        )
         else:
             linear_ops = (
                 torch.ops.cadence.quantized_linear,
                 torch.ops.cadence.quantized_fully_connected,
             )
 
         for linear_op in linear_ops:
-            output = linear_op(
-                src,
-                weight,
-                bias,
-                in_zero_point,
-                weight_zero_point,
-                out_multiplier,
-                out_shift,
-                out_zero_point,
-                typing.cast(torch.Tensor, None),
+            # Get the function name for linear_op for debugging
+            op_name = (
+                linear_op.__name__ if hasattr(linear_op, "__name__") else str(linear_op)
             )
+            if matmul:
+                assert "quantized_matmul" in op_name
+                output = linear_op(
+                    src,
+                    in_zero_point,
+                    weight,
+                    weight_zero_point,
+                    None,
+                    out_multiplier,
+                    out_shift,
+                    out_zero_point,
+                    transposed_matmul,
+                )
+            else:
+                assert (
+                    "quantized_linear" in op_name
+                    or "quantized_fully_connected" in op_name
+                )
+                bias = torch.arange(weight_shape[0]).to(torch.int32)
+                output = linear_op(
+                    src,
+                    weight,
+                    bias,
+                    in_zero_point,
+                    weight_zero_point,
+                    out_multiplier,
+                    out_shift,
+                    out_zero_point,
+                    typing.cast(torch.Tensor, None),
+                )
 
         self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
 
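Since the two op families spell the same logical call differently, a side-by-side sketch of the argument orders exercised above (dummy stand-in values, for illustration only):

import typing

import torch

# Hypothetical operands just to make the sketch self-contained.
src = torch.zeros(1, 2, dtype=torch.int8)
weight = torch.zeros(2, 2, dtype=torch.int8)
bias = torch.zeros(2, dtype=torch.int32)
in_zero_point, weight_zero_point = 2, 1
out_multiplier, out_shift, out_zero_point = 268435456, 1, 1

# quantized_linear / quantized_fully_connected: tensors first, then zero
# points, then the rescale parameters, then a trailing offset argument.
linear_args = (
    src, weight, bias, in_zero_point, weight_zero_point,
    out_multiplier, out_shift, out_zero_point,
    typing.cast(torch.Tensor, None),
)

# quantized_matmul: each tensor is followed by its own zero point, bias must
# be None (or all zeros), and a trailing flag says whether Y is pre-transposed.
matmul_args = (
    src, in_zero_point, weight, weight_zero_point, None,
    out_multiplier, out_shift, out_zero_point, False,
)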