
Commit 02a203e

Add support for strongly typed op_quantized_matmul, generalize dispatch strategy
Differential Revision: D80132832
Pull Request resolved: #13375
1 parent: ec67249

9 files changed: +539 −40 lines

backends/cadence/aot/functions.yaml

Lines changed: 10 additions & 0 deletions
```diff
@@ -234,6 +234,16 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_matmul_out

+- func: cadence::quantized_matmul_asym8sxasym8s_asym8s.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_matmul_asym8sxasym8s_asym8s_out
+
+- func: cadence::quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_matmul_asym8uxasym8u_asym8u_out
+
 - func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
```

backends/cadence/aot/functions_hifi.yaml

Lines changed: 10 additions & 0 deletions
```diff
@@ -354,6 +354,16 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_matmul_out

+- func: cadence::quantized_matmul_asym8sxasym8s_asym8s.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_matmul_asym8sxasym8s_asym8s_out
+
+- func: cadence::quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_matmul_asym8uxasym8u_asym8u_out
+
 - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
```
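The kernel-name suffix encodes the operand and output types: reading `asym8sxasym8s_asym8s` left to right gives X type × Y type → output type, where `asym8s` is asymmetric signed 8-bit (int8) and `asym8u` is asymmetric unsigned 8-bit (uint8). A small illustration of the decomposition; the helper below is hypothetical and exists only to make the naming pattern explicit:

```python
# Hypothetical helper, for illustration only: how the typed kernel names in
# the two YAML files above decompose. "asym8s" = asymmetric signed 8-bit,
# "asym8u" = asymmetric unsigned 8-bit; the suffix reads "<X>x<Y>_<out>".
def typed_kernel_name(namespace: str, base: str, x_t: str, y_t: str, out_t: str) -> str:
    return f"{namespace}::{base}_{x_t}x{y_t}_{out_t}_out"

print(typed_kernel_name("impl::reference", "quantized_matmul", "asym8s", "asym8s", "asym8s"))
# impl::reference::quantized_matmul_asym8sxasym8s_asym8s_out
print(typed_kernel_name("cadence::impl::HiFi", "quantized_matmul", "asym8u", "asym8u", "asym8u"))
# cadence::impl::HiFi::quantized_matmul_asym8uxasym8u_asym8u_out
```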

backends/cadence/aot/ops_registrations.py

Lines changed: 98 additions & 0 deletions
```diff
@@ -103,6 +103,18 @@
 lib.define(
     "quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_matmul_asym8sxasym8s_asym8s(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_matmul_asym8sxasym8s_asym8s.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_matmul_asym8uxasym8u_asym8u(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
+)

 lib.define(
     "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "

@@ -700,6 +712,92 @@ def quantized_matmul_meta(
     return X.new_empty(out_size, dtype=X.dtype)


+@register_fake("cadence::quantized_matmul_asym8sxasym8s_asym8s")
+def quantized_matmul_asym8sxasym8s_asym8s_meta(
+    X: torch.Tensor,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_zero_point: int,
+    bias: Optional[torch.Tensor],
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    transposed: bool = False,
+) -> torch.Tensor:
+    X_size = list(X.size())
+    Y_size = list(Y.size())
+
+    # Get the batch dimensions for both tensors
+    X_batch_dims = X_size[:-2]
+    Y_batch_dims = Y_size[:-2]
+
+    # If they don't match, check that they're compatible
+    if X_batch_dims != Y_batch_dims:
+        assert prod(X_batch_dims) == prod(
+            Y_batch_dims
+        ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}"
+
+    # Get the matmul output size
+    if transposed:
+        assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied"
+        mat_size = [X_size[-2], Y_size[-2]]
+    else:
+        assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied"
+        mat_size = [X_size[-2], Y_size[-1]]
+
+    # Combine the larger batch dimensions with the matmul output size
+    out_size = (
+        X_batch_dims + mat_size
+        if len(X_batch_dims) > len(Y_batch_dims)
+        else Y_batch_dims + mat_size
+    )
+
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
+@register_fake("cadence::quantized_matmul_asym8uxasym8u_asym8u")
+def quantized_matmul_asym8uxasym8u_asym8u_meta(
+    X: torch.Tensor,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_zero_point: int,
+    bias: Optional[torch.Tensor],
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    transposed: bool = False,
+) -> torch.Tensor:
+    X_size = list(X.size())
+    Y_size = list(Y.size())
+
+    # Get the batch dimensions for both tensors
+    X_batch_dims = X_size[:-2]
+    Y_batch_dims = Y_size[:-2]
+
+    # If they don't match, check that they're compatible
+    if X_batch_dims != Y_batch_dims:
+        assert prod(X_batch_dims) == prod(
+            Y_batch_dims
+        ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}"
+
+    # Get the matmul output size
+    if transposed:
+        assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied"
+        mat_size = [X_size[-2], Y_size[-2]]
+    else:
+        assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied"
+        mat_size = [X_size[-2], Y_size[-1]]
+
+    # Combine the larger batch dimensions with the matmul output size
+    out_size = (
+        X_batch_dims + mat_size
+        if len(X_batch_dims) > len(Y_batch_dims)
+        else Y_batch_dims + mat_size
+    )
+
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
 @register_fake("cadence::im2row")
 def im2row_meta(
     input: torch.Tensor,
```
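These fake (meta) kernels compute only output shapes and dtypes, so they are easy to sanity-check under FakeTensorMode. A minimal sketch, assuming `executorch.backends.cadence.aot.ops_registrations` has been imported so the `cadence` library and its fake kernels are registered:

```python
# Minimal sketch: shape inference via the registered fake kernel, assuming
# executorch.backends.cadence.aot.ops_registrations has been imported.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    X = torch.empty(5, 2, 3, dtype=torch.int8)  # batch of 5 [2, 3] matrices
    Y = torch.empty(5, 3, 4, dtype=torch.int8)  # matching batch of [3, 4]
    Z = torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s(
        X, 0, Y, 0, None, 1, 0, 0, False
    )
    print(Z.shape)  # torch.Size([5, 2, 4]), per the shape rule above
```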

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -185,3 +185,55 @@ def test_uint8_dispatch_quantized_relu(self) -> None:
             ),
             1,
         )
+
+    def test_int8_dispatch_quantized_matmul(self) -> None:
+        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_matmul"""
+        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        y = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
+        bias = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, y, bias),
+            op=exir_ops.edge.cadence.quantized_matmul.default,
+            args=(x, 0, y, 0, bias, 1, 0, 0, False),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_matmul.default),
+            0,
+        )
+        # Should be replaced with int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_matmul_asym8sxasym8s_asym8s.default,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_matmul(self) -> None:
+        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_matmul"""
+        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+        y = torch.randint(0, 255, (3, 4), dtype=torch.uint8)
+        bias = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, y, bias),
+            op=exir_ops.edge.cadence.quantized_matmul.default,
+            args=(x, 0, y, 0, bias, 1, 0, 0, False),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_matmul.default),
+            0,
+        )
+        # Should be replaced with uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_matmul_asym8uxasym8u_asym8u.default,
+            ),
+            1,
+        )
```

backends/cadence/aot/type_dispatch.py

Lines changed: 47 additions & 40 deletions
```diff
@@ -6,6 +6,9 @@

 # pyre-strict

+from dataclasses import dataclass
+from typing import Optional
+
 import torch
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,

@@ -17,29 +20,42 @@
 from torch.fx.node import Argument


+@dataclass
+class OpConfig:
+    """Configuration for type dispatch operations."""
+
+    base_name: str
+    input_arg_idx: int = 0
+    weight_arg_idx: Optional[int] = None
+    variant: str = "per_tensor"
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=4))
 class CompileTimeTypeDispatchPass(ExportPass):
     """
     Replaces generic ops with ops that have explicit types.
     """

-    _BINARY_TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, torch.dtype], str] = {
+    _TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, ...], str] = {
+        (torch.int8,): "asym8s_asym8s",
+        (torch.uint8,): "asym8u_asym8u",
         (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
         (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
     }

-    _UNARY_TYPE_DISPATCH_MAP: dict[torch.dtype, str] = {
-        torch.int8: "asym8s_asym8s",
-        torch.uint8: "asym8u_asym8u",
-    }
-
-    _BINARY_SUPPORTED_OPS: dict[OpOverload, str] = {
-        exir_ops.edge.cadence.quantized_fully_connected.per_tensor: "quantized_fully_connected",
-        exir_ops.edge.cadence.quantized_linear.per_tensor: "quantized_linear",
-    }
-
-    _SUPPORTED_UNARY_OPS: dict[OpOverload, str] = {
-        exir_ops.edge.cadence.quantized_relu.per_tensor: "quantized_relu",
+    _SUPPORTED_OPS: dict[OpOverload, OpConfig] = {
+        exir_ops.edge.cadence.quantized_fully_connected.per_tensor: OpConfig(
+            "quantized_fully_connected", input_arg_idx=0, weight_arg_idx=1
+        ),
+        exir_ops.edge.cadence.quantized_linear.per_tensor: OpConfig(
+            "quantized_linear", input_arg_idx=0, weight_arg_idx=1
+        ),
+        exir_ops.edge.cadence.quantized_matmul.default: OpConfig(
+            "quantized_matmul", input_arg_idx=0, weight_arg_idx=2, variant="default"
+        ),
+        exir_ops.edge.cadence.quantized_relu.per_tensor: OpConfig(
+            "quantized_relu", input_arg_idx=0
+        ),
     }

     def call_operator(

@@ -49,37 +65,28 @@ def call_operator(
         kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        if op in self._BINARY_SUPPORTED_OPS:
-            # pyre-ignore[16]: None has no attribute `to_tensor`.
-            input_dtype = args[0].to_tensor().dtype
-            weight_dtype = args[1].to_tensor().dtype
-            dtype_pair = (input_dtype, weight_dtype)
-
-            if dtype_pair not in self._BINARY_TYPE_DISPATCH_MAP:
-                raise RuntimeError(
-                    f"Unsupported input types for {op}: {input_dtype} and {weight_dtype}"
-                )
-
-            base_op_name = self._BINARY_SUPPORTED_OPS[op]
-            type_suffix = self._BINARY_TYPE_DISPATCH_MAP[dtype_pair]
-
-            typed_op_name = f"{base_op_name}_{type_suffix}"
-            typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
+        if op not in self._SUPPORTED_OPS:
+            return super().call_operator(op, args, kwargs, meta)

-            return super().call_operator(typed_op, args, kwargs, meta)
+        config = self._SUPPORTED_OPS[op]

-        elif op in self._SUPPORTED_UNARY_OPS:
-            input_dtype = args[0].to_tensor().dtype
+        # pyre-ignore[16]: None has no attribute `to_tensor`.
+        input_dtype = args[config.input_arg_idx].to_tensor().dtype

-            if input_dtype not in self._UNARY_TYPE_DISPATCH_MAP:
-                raise RuntimeError(f"Unsupported input type for {op}: {input_dtype}")
+        if config.weight_arg_idx is not None:
+            weight_dtype = args[config.weight_arg_idx].to_tensor().dtype
+            dtype_key = (input_dtype, weight_dtype)
+        else:
+            dtype_key = (input_dtype,)

-            base_op_name = self._SUPPORTED_UNARY_OPS[op]
-            type_suffix = self._UNARY_TYPE_DISPATCH_MAP[input_dtype]
+        if dtype_key not in self._TYPE_DISPATCH_MAP:
+            raise RuntimeError(f"Unsupported input types for {op}: {dtype_key}")

-            typed_op_name = f"{base_op_name}_{type_suffix}"
-            typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
+        type_suffix = self._TYPE_DISPATCH_MAP[dtype_key]
+        typed_op_name = f"{config.base_name}_{type_suffix}"

-            return super().call_operator(typed_op, args, kwargs, meta)
+        typed_op = getattr(
+            getattr(exir_ops.edge.cadence, typed_op_name), config.variant
+        )

-        return super().call_operator(op, args, kwargs, meta)
+        return super().call_operator(typed_op, args, kwargs, meta)
```
