Skip to content

Commit 368ca27

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Enable strongly typed ops for deployment (#13230)
Summary: Pull Request resolved: #13230 As titled. This allows a severe reduction in codesize by only using the bare minimum cpp code. Right now, this diff only implements that option for quantized fully connected per tensor. This diff is introducing opt level 4 to use for deployment purposes. The idea is similar to -Os in most compilers, we just use 4 for simplicity and because for now, everything should be inclusive by construction. This decision can be re-visited later! Differential Revision: D79867630
1 parent 6fd97ab commit 368ca27

11 files changed

+470
-4
lines changed

backends/cadence/aot/TARGETS

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ python_library(
101101
":reorder_ops",
102102
":replace_ops",
103103
":simplify_ops",
104+
":type_dispatch",
104105
":utils",
105106
"//caffe2:torch",
106107
"//executorch/exir:pass_base",
@@ -322,6 +323,37 @@ python_library(
322323
],
323324
)
324325

326+
python_library(
327+
name = "type_dispatch",
328+
srcs = [
329+
"type_dispatch.py",
330+
],
331+
typing = True,
332+
deps = [
333+
"//caffe2:torch",
334+
"//executorch/backends/cadence/aot:pass_utils",
335+
"//executorch/exir:pass_base",
336+
],
337+
)
338+
339+
python_unittest(
340+
name = "test_type_dispatch_passes",
341+
srcs = [
342+
"tests/test_type_dispatch_passes.py",
343+
],
344+
supports_static_listing = False,
345+
typing = True,
346+
deps = [
347+
":ops_registrations",
348+
":type_dispatch",
349+
"//caffe2:torch",
350+
"//executorch/backends/cadence/aot:graph_builder",
351+
"//executorch/backends/cadence/aot:pass_utils",
352+
"//executorch/exir:pass_base",
353+
"//executorch/exir/dialects:lib",
354+
],
355+
)
356+
325357
python_library(
326358
name = "typing_stubs",
327359
srcs = [

backends/cadence/aot/functions.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,16 @@
254254
- arg_meta: null
255255
kernel_name: impl::reference::quantized_fully_connected_per_tensor_out
256256

257+
- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
258+
kernels:
259+
- arg_meta: null
260+
kernel_name: impl::reference::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out
261+
262+
- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
263+
kernels:
264+
- arg_meta: null
265+
kernel_name: impl::reference::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out
266+
257267
- func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)
258268
kernels:
259269
- arg_meta: null

backends/cadence/aot/functions_hifi.yaml

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,17 +329,27 @@
329329
- arg_meta: null
330330
kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out
331331

332-
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
332+
- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
333333
kernels:
334334
- arg_meta: null
335-
kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
335+
kernel_name: cadence::impl::HiFi::quantized_matmul_out
336336

337-
- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
337+
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
338338
kernels:
339339
- arg_meta: null
340-
kernel_name: cadence::impl::HiFi::quantized_matmul_out
340+
kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
341341

342342
- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
343343
kernels:
344344
- arg_meta: null
345345
kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out
346+
347+
- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
348+
kernels:
349+
- arg_meta: null
350+
kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out
351+
352+
- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
353+
kernels:
354+
- arg_meta: null
355+
kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out

backends/cadence/aot/ops_registrations.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,14 @@
162162
"quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
163163
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
164164
)
165+
lib.define(
166+
"quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
167+
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
168+
)
169+
lib.define(
170+
"quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
171+
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
172+
)
165173
lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)")
166174
lib.define(
167175
"where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
@@ -240,6 +248,14 @@
240248
"quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
241249
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
242250
)
251+
lib.define(
252+
"quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
253+
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
254+
)
255+
lib.define(
256+
"quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
257+
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
258+
)
243259
lib.define(
244260
"quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
245261
"Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -754,6 +770,50 @@ def quantized_fully_connected_per_tensor_meta(
754770
return src.new_empty(out_size, dtype=src.dtype)
755771

756772

773+
@register_fake("cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor")
774+
def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
775+
src: torch.Tensor,
776+
weight: torch.Tensor,
777+
bias: torch.Tensor,
778+
in_zero_point: int,
779+
weight_zero_point: int,
780+
out_multiplier: int,
781+
out_shift: int,
782+
out_zero_point: int,
783+
offset: Optional[torch.Tensor],
784+
) -> torch.Tensor:
785+
# src comes in shape [leading_dims, in_dim]
786+
# weight comes in shape [out_dim, in_dim]
787+
# output comes in empty with shape [leading_dims, out_dim]
788+
out_size = list(src.size())
789+
weight_size = list(weight.size())
790+
assert len(weight_size) == 2
791+
out_size[-1] = weight_size[0]
792+
return src.new_empty(out_size, dtype=src.dtype)
793+
794+
795+
@register_fake("cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor")
796+
def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
797+
src: torch.Tensor,
798+
weight: torch.Tensor,
799+
bias: torch.Tensor,
800+
in_zero_point: int,
801+
weight_zero_point: int,
802+
out_multiplier: int,
803+
out_shift: int,
804+
out_zero_point: int,
805+
offset: Optional[torch.Tensor],
806+
) -> torch.Tensor:
807+
# src comes in shape [leading_dims, in_dim]
808+
# weight comes in shape [out_dim, in_dim]
809+
# output comes in empty with shape [leading_dims, out_dim]
810+
out_size = list(src.size())
811+
weight_size = list(weight.size())
812+
assert len(weight_size) == 2
813+
out_size[-1] = weight_size[0]
814+
return src.new_empty(out_size, dtype=src.dtype)
815+
816+
757817
@register_fake("cadence::convolution")
758818
def convolution_meta(
759819
input: torch.Tensor,

backends/cadence/aot/passes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ReplaceMulTensorWithMulAndFullOpsPass,
3434
)
3535
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
36+
from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
3637
from executorch.exir import EdgeProgramManager
3738
from executorch.exir.pass_base import ExportPass, PassResult
3839
from executorch.exir.pass_manager import PassManager, PassType
@@ -90,6 +91,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
9091
FuseFullThenReshapePass,
9192
FuseTransposeOrPermuteOpPairsPass,
9293
RemoveNopSliceOrViewOpPass,
94+
CompileTimeTypeDispatchPass,
9395
]
9496
return pytree.tree_flatten(passes)[0]
9597

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# pyre-strict
7+
8+
import unittest
9+
from typing import cast
10+
11+
import executorch.backends.cadence.aot.ops_registrations # noqa
12+
import torch
13+
from executorch.backends.cadence.aot.graph_builder import single_op_builder
14+
from executorch.backends.cadence.aot.pass_utils import count_node
15+
from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
from torch.fx.passes.infra.pass_base import PassResult
18+
19+
20+
class TestTypeDispatchPasses(unittest.TestCase):
21+
def test_int8_dispatch(self) -> None:
22+
"""Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
23+
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
24+
w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
25+
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
26+
gm = single_op_builder(
27+
placeholders=(x, w, b),
28+
op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
29+
args=(x, w, b, 0, 0, 1, 0, 0, None),
30+
)
31+
p = CompileTimeTypeDispatchPass()
32+
gm = cast(PassResult, p(gm)).graph_module
33+
# Original op should be replaced
34+
self.assertEqual(
35+
count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
36+
0,
37+
)
38+
# Should be replaced with int8 specific variant
39+
self.assertEqual(
40+
count_node(
41+
gm,
42+
exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
43+
),
44+
1,
45+
)
46+
47+
def test_uint8_dispatch(self) -> None:
48+
"""Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
49+
x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
50+
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
51+
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
52+
gm = single_op_builder(
53+
placeholders=(x, w, b),
54+
op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
55+
args=(x, w, b, 0, 0, 1, 0, 0, None),
56+
)
57+
p = CompileTimeTypeDispatchPass()
58+
gm = cast(PassResult, p(gm)).graph_module
59+
# Original op should be replaced
60+
self.assertEqual(
61+
count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
62+
0,
63+
)
64+
# Should be replaced with uint8 specific variant
65+
self.assertEqual(
66+
count_node(
67+
gm,
68+
exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
69+
),
70+
1,
71+
)
72+
73+
def test_mixed_types_error(self) -> None:
74+
"""Test mixed int8/uint8 inputs should raise RuntimeError"""
75+
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
76+
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
77+
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
78+
gm = single_op_builder(
79+
placeholders=(x, w, b),
80+
op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
81+
args=(x, w, b, 0, 0, 1, 0, 0, None),
82+
)
83+
p = CompileTimeTypeDispatchPass()
84+
# Mixed types should raise RuntimeError
85+
with self.assertRaises(RuntimeError) as context:
86+
cast(PassResult, p(gm)).graph_module
87+
self.assertIn("Unsupported input types", str(context.exception))
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
from executorch.backends.cadence.aot.pass_utils import (
11+
CadencePassAttribute,
12+
register_cadence_pass,
13+
)
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
16+
from torch._ops import OpOverload
17+
from torch.fx.node import Argument
18+
19+
@register_cadence_pass(CadencePassAttribute(opt_level=4))
20+
class CompileTimeTypeDispatchPass(ExportPass):
21+
"""
22+
Replaces generic ops with ops that have explicit types.
23+
"""
24+
25+
def call_operator(
26+
self,
27+
op: OpOverload,
28+
args: tuple[Argument, ...],
29+
kwargs: dict[str, Argument],
30+
meta: NodeMetadata,
31+
) -> ProxyValue:
32+
if op not in {
33+
exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
34+
}:
35+
return super().call_operator(op, args, kwargs, meta)
36+
37+
if (
38+
# pyre-ignore[16]: None has no attribute `to_tensor`.
39+
args[0].to_tensor().dtype == torch.int8
40+
and args[1].to_tensor().dtype == torch.int8
41+
):
42+
return super().call_operator(
43+
exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
44+
args,
45+
kwargs,
46+
meta,
47+
)
48+
elif (
49+
args[0].to_tensor().dtype == torch.uint8
50+
and args[1].to_tensor().dtype == torch.uint8
51+
):
52+
return super().call_operator(
53+
exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
54+
args,
55+
kwargs,
56+
meta,
57+
)
58+
else:
59+
raise RuntimeError(
60+
f"Unsupported input types for {op}: {args[0].to_tensor().dtype} and {args[1].to_tensor().dtype}"
61+
)

0 commit comments

Comments
 (0)