Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ python_library(
":reorder_ops",
":replace_ops",
":simplify_ops",
":type_dispatch",
":utils",
"//caffe2:torch",
"//executorch/exir:pass_base",
Expand Down Expand Up @@ -322,6 +323,37 @@ python_library(
],
)

python_library(
name = "type_dispatch",
srcs = [
"type_dispatch.py",
],
typing = True,
deps = [
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
],
)

python_unittest(
name = "test_type_dispatch_passes",
srcs = [
"tests/test_type_dispatch_passes.py",
],
supports_static_listing = False,
typing = True,
deps = [
":ops_registrations",
":type_dispatch",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

python_library(
name = "typing_stubs",
srcs = [
Expand Down
10 changes: 10 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@
- arg_meta: null
kernel_name: impl::reference::quantized_fully_connected_per_tensor_out

- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out

- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out

- func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
18 changes: 14 additions & 4 deletions backends/cadence/aot/functions_hifi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -329,17 +329,27 @@
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out

- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
kernel_name: cadence::impl::HiFi::quantized_matmul_out

- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_matmul_out
kernel_name: cadence::impl::HiFi::quantized_fully_connected_out

- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out

- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out

- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out
60 changes: 60 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@
"quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define(
"quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define(
"quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)")
lib.define(
"where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
Expand Down Expand Up @@ -240,6 +248,14 @@
"quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
"Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
Expand Down Expand Up @@ -754,6 +770,50 @@ def quantized_fully_connected_per_tensor_meta(
return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor")
def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_zero_point: int,
weight_zero_point: int,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
offset: Optional[torch.Tensor],
) -> torch.Tensor:
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
out_size[-1] = weight_size[0]
return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor")
def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_zero_point: int,
weight_zero_point: int,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
offset: Optional[torch.Tensor],
) -> torch.Tensor:
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
out_size[-1] = weight_size[0]
return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::convolution")
def convolution_meta(
input: torch.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ReplaceMulTensorWithMulAndFullOpsPass,
)
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
from executorch.exir import EdgeProgramManager
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_manager import PassManager, PassType
Expand Down Expand Up @@ -90,6 +91,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
FuseFullThenReshapePass,
FuseTransposeOrPermuteOpPairsPass,
RemoveNopSliceOrViewOpPass,
CompileTimeTypeDispatchPass,
]
return pytree.tree_flatten(passes)[0]

Expand Down
87 changes: 87 additions & 0 deletions backends/cadence/aot/tests/test_type_dispatch_passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest
from typing import cast

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.infra.pass_base import PassResult


class TestTypeDispatchPasses(unittest.TestCase):
def test_int8_dispatch(self) -> None:
"""Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(
placeholders=(x, w, b),
op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
args=(x, w, b, 0, 0, 1, 0, 0, None),
)
p = CompileTimeTypeDispatchPass()
gm = cast(PassResult, p(gm)).graph_module
# Original op should be replaced
self.assertEqual(
count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
0,
)
# Should be replaced with int8 specific variant
self.assertEqual(
count_node(
gm,
exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
),
1,
)

def test_uint8_dispatch(self) -> None:
"""Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(
placeholders=(x, w, b),
op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
args=(x, w, b, 0, 0, 1, 0, 0, None),
)
p = CompileTimeTypeDispatchPass()
gm = cast(PassResult, p(gm)).graph_module
# Original op should be replaced
self.assertEqual(
count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
0,
)
# Should be replaced with uint8 specific variant
self.assertEqual(
count_node(
gm,
exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
),
1,
)

def test_mixed_types_error(self) -> None:
"""Test mixed int8/uint8 inputs should raise RuntimeError"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(
placeholders=(x, w, b),
op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
args=(x, w, b, 0, 0, 1, 0, 0, None),
)
p = CompileTimeTypeDispatchPass()
# Mixed types should raise RuntimeError
with self.assertRaises(RuntimeError) as context:
cast(PassResult, p(gm)).graph_module
self.assertIn("Unsupported input types", str(context.exception))
62 changes: 62 additions & 0 deletions backends/cadence/aot/type_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
register_cadence_pass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch._ops import OpOverload
from torch.fx.node import Argument


@register_cadence_pass(CadencePassAttribute(opt_level=4))
class CompileTimeTypeDispatchPass(ExportPass):
"""
Replaces generic ops with ops that have explicit types.
"""

def call_operator(
self,
op: OpOverload,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in {
exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
}:
return super().call_operator(op, args, kwargs, meta)

if (
# pyre-ignore[16]: None has no attribute `to_tensor`.
args[0].to_tensor().dtype == torch.int8
and args[1].to_tensor().dtype == torch.int8
):
return super().call_operator(
exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
args,
kwargs,
meta,
)
elif (
args[0].to_tensor().dtype == torch.uint8
and args[1].to_tensor().dtype == torch.uint8
):
return super().call_operator(
exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
args,
kwargs,
meta,
)
else:
raise RuntimeError(
f"Unsupported input types for {op}: {args[0].to_tensor().dtype} and {args[1].to_tensor().dtype}"
)
Loading
Loading