diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 8492bb55877..e257df37c8a 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -101,6 +101,7 @@ python_library( ":reorder_ops", ":replace_ops", ":simplify_ops", + ":type_dispatch", ":utils", "//caffe2:torch", "//executorch/exir:pass_base", @@ -322,6 +323,37 @@ python_library( ], ) +python_library( + name = "type_dispatch", + srcs = [ + "type_dispatch.py", + ], + typing = True, + deps = [ + "//caffe2:torch", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/exir:pass_base", + ], +) + +python_unittest( + name = "test_type_dispatch_passes", + srcs = [ + "tests/test_type_dispatch_passes.py", + ], + supports_static_listing = False, + typing = True, + deps = [ + ":ops_registrations", + ":type_dispatch", + "//caffe2:torch", + "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + python_library( name = "typing_stubs", srcs = [ diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 9dbf28f3114..68146760d9b 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -254,6 +254,16 @@ - arg_meta: null kernel_name: impl::reference::quantized_fully_connected_per_tensor_out +- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out + +- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out + - func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 04228f40be7..7a9000b530b 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -329,17 +329,27 @@ - arg_meta: null kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out -- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_fully_connected_out + kernel_name: cadence::impl::HiFi::quantized_matmul_out -- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_matmul_out + kernel_name: cadence::impl::HiFi::quantized_fully_connected_out - func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out + +- func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out + +- func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 5713861103c..91ed3560a04 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -162,6 +162,14 @@ "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" ) +lib.define( + "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" +) +lib.define( + "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" +) lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)") lib.define( "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)" @@ -240,6 +248,14 @@ "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, " "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)" @@ -754,6 +770,50 @@ def quantized_fully_connected_per_tensor_meta( return src.new_empty(out_size, dtype=src.dtype) +@register_fake("cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor") +def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: int, + weight_zero_point: int, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + +@register_fake("cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor") +def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: int, + weight_zero_point: int, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + @register_fake("cadence::convolution") def convolution_meta( input: torch.Tensor, diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index d7c692f12e9..bb4a8f065d5 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -33,6 +33,7 @@ ReplaceMulTensorWithMulAndFullOpsPass, ) from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph +from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass from executorch.exir import EdgeProgramManager from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.pass_manager import PassManager, PassType @@ -90,6 +91,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]: FuseFullThenReshapePass, FuseTransposeOrPermuteOpPairsPass, RemoveNopSliceOrViewOpPass, + CompileTimeTypeDispatchPass, ] return pytree.tree_flatten(passes)[0] diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py new file mode 100644 index 00000000000..f29a13a5bf8 --- /dev/null +++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# pyre-strict + +import unittest +from typing import cast + +import executorch.backends.cadence.aot.ops_registrations # noqa +import torch +from executorch.backends.cadence.aot.graph_builder import single_op_builder +from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx.passes.infra.pass_base import PassResult + + +class TestTypeDispatchPasses(unittest.TestCase): + def test_int8_dispatch(self) -> None: + """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant""" + x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) + w = torch.randint(-128, 127, (4, 3), dtype=torch.int8) + b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor, + args=(x, w, b, 0, 0, 1, 0, 0, None), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor), + 0, + ) + # Should be replaced with int8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, + ), + 1, + ) + + def test_uint8_dispatch(self) -> None: + """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant""" + x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) + w = torch.randint(0, 255, (4, 3), dtype=torch.uint8) + b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor, + args=(x, w, b, 0, 0, 1, 0, 0, None), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor), + 0, + ) + # Should be replaced with uint8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, + ), + 1, + ) + + def test_mixed_types_error(self) -> None: + """Test mixed int8/uint8 inputs should raise RuntimeError""" + x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) + w = torch.randint(0, 255, (4, 3), dtype=torch.uint8) + b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor, + args=(x, w, b, 0, 0, 1, 0, 0, None), + ) + p = CompileTimeTypeDispatchPass() + # Mixed types should raise RuntimeError + with self.assertRaises(RuntimeError) as context: + cast(PassResult, p(gm)).graph_module + self.assertIn("Unsupported input types", str(context.exception)) diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py new file mode 100644 index 00000000000..431fcd4a0f2 --- /dev/null +++ b/backends/cadence/aot/type_dispatch.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import torch +from executorch.backends.cadence.aot.pass_utils import ( + CadencePassAttribute, + register_cadence_pass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch._ops import OpOverload +from torch.fx.node import Argument + + +@register_cadence_pass(CadencePassAttribute(opt_level=4)) +class CompileTimeTypeDispatchPass(ExportPass): + """ + Replaces generic ops with ops that have explicit types. + """ + + def call_operator( + self, + op: OpOverload, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.cadence.quantized_fully_connected.per_tensor, + }: + return super().call_operator(op, args, kwargs, meta) + + if ( + # pyre-ignore[16]: None has no attribute `to_tensor`. + args[0].to_tensor().dtype == torch.int8 + and args[1].to_tensor().dtype == torch.int8 + ): + return super().call_operator( + exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, + args, + kwargs, + meta, + ) + elif ( + args[0].to_tensor().dtype == torch.uint8 + and args[1].to_tensor().dtype == torch.uint8 + ): + return super().call_operator( + exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, + args, + kwargs, + meta, + ) + else: + raise RuntimeError( + f"Unsupported input types for {op}: {args[0].to_tensor().dtype} and {args[1].to_tensor().dtype}" + ) diff --git a/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..5e3a5173f32 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using std::optional; + +void quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + int64_t leading_dims = 1; + int64_t out_dim = weight.size(0); // = out_dim + int64_t in_dim = weight.size(1); // = in_dim + + const int8_t* __restrict__ in_data = in.const_data_ptr(); + const int8_t* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + int8_t* __restrict__ out_data = out.mutable_data_ptr(); + + int32_t ret = xa_nn_fully_connected_asym8sxasym8s_asym8s( + out_data, + weight_data, + in_data, + bias_data, + in_dim, // weight_depth, number of columns in weight + out_dim, // out_depth, number of rows in weight + -in_zero_point, + -static_cast(weight_zero_point), + static_cast(out_multiplier), + static_cast(out_shift), + out_zero_point); + ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed"); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..80509fdd5db --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using std::optional; + +void quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + int64_t leading_dims = 1; + int64_t out_dim = weight.size(0); // = out_dim + int64_t in_dim = weight.size(1); // = in_dim + + const uint8_t* __restrict__ in_data = in.const_data_ptr(); + const uint8_t* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + uint8_t* __restrict__ out_data = out.mutable_data_ptr(); + + int32_t ret = xa_nn_fully_connected_asym8uxasym8u_asym8u( + out_data, + weight_data, + in_data, + bias_data, + in_dim, // weight_depth, number of columns in weight + out_dim, // out_depth, number of rows in weight + -in_zero_point, + -static_cast(weight_zero_point), + static_cast(out_multiplier), + static_cast(out_shift), + out_zero_point); + ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed"); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index bd9658cc2f9..9a797874cef 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -65,6 +65,8 @@ OPERATORS = [ "pow", "quantized_conv_out", "quantized_fully_connected_out", + "quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out", + "quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out", "quantized_layer_norm", "quantized_linear_out", "quantized_matmul_out", diff --git a/backends/cadence/reference/operators/quantized_fully_connected_out.cpp b/backends/cadence/reference/operators/quantized_fully_connected_out.cpp index fe41c2d7e77..136055de70a 100644 --- a/backends/cadence/reference/operators/quantized_fully_connected_out.cpp +++ b/backends/cadence/reference/operators/quantized_fully_connected_out.cpp @@ -92,6 +92,80 @@ void quantized_fully_connected_per_tensor_out( #undef typed_quantized_linear } +void quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear +} + +void quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear +} + }; // namespace native }; // namespace reference }; // namespace impl