
Commit 98e8b2d

Andrew Grebenisan authored and facebook-github-bot committed
Support for batched matmul (#14956)
Summary: Matmul was relying on the linear infra, which didn't support a batched second argument. This change adds that support. Differential Revision: D84279595
1 parent 3591604 commit 98e8b2d
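
For context on the limitation being fixed: the old reference path routed quantized_matmul through quantized_linear_common, and torch.nn.functional.linear expects a 2-D weight of shape (out_features, in_features), so a second operand with a leading batch dimension could not be expressed through it. Below is a minimal sketch of the shape behavior with made-up tensors (x and y_batched are illustrative names, not from the repo); as far as I know, linear rejects a 3-D weight on current PyTorch builds, but treat that as an assumption.

    import torch

    x = torch.randn(2, 1, 2)          # (batch, rows, inner)
    y_batched = torch.randn(2, 2, 2)  # (batch, inner, cols): a batched second operand

    # torch.matmul broadcasts over the leading batch dimension.
    print(torch.matmul(x, y_batched).shape)  # torch.Size([2, 1, 2])

    # F.linear wants a 2-D weight, so the batched operand cannot go through it.
    try:
        torch.nn.functional.linear(x, y_batched)
    except RuntimeError as err:  # assumed failure mode on a 3-D weight
        print("linear rejects a batched weight:", err)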

File tree: 2 files changed, +41 −16 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 17 additions & 15 deletions
@@ -62,7 +62,7 @@ def quantize_per_tensor(
     ]
     if dtype not in supported_quant_types:
         raise ValueError(
-            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
+            f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_quant_types}"
         )
 
     return torch.ops.quantized_decomposed.quantize_per_tensor(

@@ -264,7 +264,7 @@ def quantized_linear_common(
     supported_dtypes = [torch.int8, torch.uint8, torch.int32]
     if dtype not in supported_dtypes:
         raise ValueError(
-            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
+            f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}"
         )
 
     out = torch.nn.functional.linear(

@@ -427,25 +427,27 @@ def quantized_matmul(
     - out_multiplier (int): The multiplier used to scale the output
     - out_shift (int): The shift used to scale the output
     - out_zero_point (int): The quantized mapping of zero for the output
-    - transposed (bool): Whether to transpose the weight tensor
+    - transposed (bool): Whether Y is transposed.
     """
     if bias is not None and not torch.all(bias == 0):
         raise ValueError("bias must be None or all zeros since unused in out variant")
 
-    # Looks weird, but quantized linear assumes weights are pre-transposed,
-    # hence we transpose only if `transposed` is False.
-    if not transposed:
-        Y = Y.T
+    if transposed:
+        Y = Y.transpose(-1, -2)
 
-    return quantized_linear_common(
-        X,
-        Y,
-        bias or torch.zeros(1, dtype=torch.int32),
-        X_zero_point,
-        Y_zero_point,
-        out_multiplier,
-        out_shift,
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
+
+    out = torch.matmul(
+        (X - X_zero_point).float(),
+        (Y - Y_zero_point).float(),
+    )
+    return quantize_per_tensor(
+        out,
+        out_scale,
         out_zero_point,
+        torch.iinfo(X.dtype).min,
+        torch.iinfo(X.dtype).max,
+        X.dtype,
     )
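
For readers skimming the diff, here is a minimal standalone sketch of what the new batched path computes (hypothetical shapes and zero points; this is not the repo's code, and it stops before the out_multiplier / out_shift requantization step):

    import torch

    # Hypothetical int8 operands, both carrying a leading batch dimension.
    X = torch.randint(-5, 5, (2, 1, 2), dtype=torch.int8)
    Y = torch.randint(-5, 5, (2, 2, 2), dtype=torch.int8)
    X_zero_point, Y_zero_point = 2, 1

    # Zero-point-corrected accumulation; torch.matmul broadcasts over the batch
    # dims, which the old quantized_linear_common route could not do for Y.
    acc = torch.matmul((X - X_zero_point).float(), (Y - Y_zero_point).float())
    print(acc.shape)  # torch.Size([2, 1, 2])

    # The implementation above then requantizes `acc` via quantize_per_tensor,
    # using a scale derived from out_multiplier / out_shift and clamping to the
    # torch.iinfo range of X.dtype.
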
backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 24 additions & 1 deletion
@@ -350,6 +350,29 @@ def test_quantized_add(
             for (matmul, transposed_matmul) in ((True, False), (True, True))
             for (per_tensor, dtype) in ((True, torch.int8),)
         ],
+        *[
+            (
+                torch.Size([2, 1, 2]),  # src_shape: 1 sample, 2 input features
+                torch.Size(
+                    [2, 2, 2]
+                ),  # weight_shape: 2 output features, 2 input features
+                2,  # in_zero_point
+                torch.tensor([1, 1], dtype=dtype),  # weight_zero_point
+                torch.tensor(
+                    [268435456], dtype=torch.int32
+                ),  # out_multiplier (0.125 * 2^31)
+                torch.tensor(
+                    [1], dtype=torch.int32
+                ),  # out_shift (shift=1, doubles the scale)
+                1,  # out_zero_point
+                torch.tensor([[[1, 2]], [[0, -1]]], dtype=dtype),  # expected_output
+                per_tensor,
+                matmul,
+                transposed_matmul,
+            )
+            for (matmul, transposed_matmul) in ((True, False), (True, True))
+            for (per_tensor, dtype) in ((True, torch.int8),)
+        ],
     ]
 )
 def test_quantized_linear(
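
The new parameterized case above pushes a batched weight (weight_shape [2, 2, 2]) through the matmul branch. As a small sanity check of the fixed-point encoding described by the test comments, here is the rescale factor those values represent (the reciprocal/sign convention applied inside the reference implementation is deliberately left out of this sketch):

    out_multiplier = 268435456  # per the test comment: 0.125 * 2**31
    out_shift = 1               # per the test comment: each unit of shift doubles the scale

    requant_scale = (out_multiplier / 2**31) * (2 ** out_shift)
    print(requant_scale)  # 0.25
    assert out_multiplier == int(0.125 * 2**31)
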
@@ -380,7 +403,7 @@ def test_quantized_linear(
             .to(expected_output.dtype)
         )
         if matmul and not transposed_matmul:
-            weight = weight.T
+            weight = weight.transpose(-1, -2)
 
         if per_tensor:
             weight_zero_point = weight_zero_point[0]
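
The switch from weight.T to weight.transpose(-1, -2) matters once the weight gains a batch dimension: Tensor.T reverses every dimension (and is deprecated for tensors that are not 2-D), whereas transpose(-1, -2) swaps only the two matrix dimensions and leaves the batch dimension in place. A quick illustration on an arbitrary 3-D tensor, not taken from the test:

    import torch

    w = torch.zeros(2, 3, 4)           # (batch, rows, cols)
    print(w.transpose(-1, -2).shape)   # torch.Size([2, 4, 3]): batch dim stays first
    print(w.permute(2, 1, 0).shape)    # torch.Size([4, 3, 2]): what .T does, all dims reversed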
