From e2050e7b85b417804d810326e43d8b76656cd4dd Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 17 Apr 2025 10:28:16 -0700
Subject: [PATCH 1/2] [cortex-m] Add scalar c++ op for quantize_per_tensor

Only the Buck build for now; CMake is next. No MVE, scalar only. Strictly
the dtypes we care about; arg_meta is updated to reflect that.

Differential Revision: [D73141767](https://our.internmc.facebook.com/intern/diff/D73141767/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D73141767/)!

[ghstack-poisoned]
---
 backends/cortex_m/ops/TARGETS                 |   9 +-
 .../cortex_m/ops/op_quantize_per_tensor.cpp   | 151 ++++++++++++++++++
 backends/cortex_m/ops/operators.yaml          |  11 ++
 backends/cortex_m/ops/targets.bzl             |  68 ++++++++
 backends/cortex_m/test/TARGETS                |   7 +-
 .../test/op_quantize_per_tensor_test.cpp      |  48 ++++++
 backends/cortex_m/test/targets.bzl            |  36 +++++
 7 files changed, 325 insertions(+), 5 deletions(-)
 create mode 100644 backends/cortex_m/ops/op_quantize_per_tensor.cpp
 create mode 100644 backends/cortex_m/ops/operators.yaml
 create mode 100644 backends/cortex_m/ops/targets.bzl
 create mode 100644 backends/cortex_m/test/op_quantize_per_tensor_test.cpp
 create mode 100644 backends/cortex_m/test/targets.bzl

diff --git a/backends/cortex_m/ops/TARGETS b/backends/cortex_m/ops/TARGETS
index 81d3bc6eb6e..e02f096fd83 100644
--- a/backends/cortex_m/ops/TARGETS
+++ b/backends/cortex_m/ops/TARGETS
@@ -5,8 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
-load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
-load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
+load("targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
@@ -17,5 +16,7 @@ python_library(
     ],
     deps = [
         "fbcode//caffe2:torch",
-    ]
-)
+    ],
+)
+
+define_common_targets()
diff --git a/backends/cortex_m/ops/op_quantize_per_tensor.cpp b/backends/cortex_m/ops/op_quantize_per_tensor.cpp
new file mode 100644
index 00000000000..2303953ea91
--- /dev/null
+++ b/backends/cortex_m/ops/op_quantize_per_tensor.cpp
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <cinttypes>
+#include <cmath>
+#include <limits>
+
+// Check for Helium/MVE support
+#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1)
+#include <arm_mve.h>
+#define HAS_HELIUM_SIMD 1
+#endif
+
+namespace cortex_m {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+namespace {
+
+/**
+ * Asserts that the parameters are valid for float to int8 quantization.
+ */
+void check_quantize_args(
+    const Tensor& input,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  // Ensure input is float type
+  ET_CHECK_MSG(
+      input.scalar_type() == ScalarType::Float,
+      "input.scalar_type() %" PRId8 " is not float type",
+      static_cast<int8_t>(input.scalar_type()));
+
+  // Check output dtype is int8 (Char)
+  ET_CHECK_MSG(
+      out.scalar_type() == ScalarType::Char,
+      "out.scalar_type() %" PRId8 " is not int8 (Char)",
+      static_cast<int8_t>(out.scalar_type()));
+
+  // Check dtype is int8 (Char)
+  ET_CHECK_MSG(
+      dtype == ScalarType::Char,
+      "dtype %" PRId8 " is not int8 (Char)",
+      static_cast<int8_t>(dtype));
+
+  // Validate quant_min and quant_max for int8
+  int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
+  int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();
+
+  ET_CHECK_MSG(
+      quant_min >= quant_min_lower_bound,
+      "quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
+      " actual quant_min: %" PRId64,
+      quant_min_lower_bound,
+      quant_min);
+
+  ET_CHECK_MSG(
+      quant_max <= quant_max_upper_bound,
+      "quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
+      " actual quant_max: %" PRId64,
+      quant_max_upper_bound,
+      quant_max);
+}
+
+/**
+ * Scalar implementation of quantization for a single value.
+ */
+template <typename T, typename K>
+T quantize_val(
+    float inv_scale,
+    int32_t zero_point,
+    K value,
+    int64_t quant_min,
+    int64_t quant_max) {
+  int32_t qvalue = zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
+  qvalue = std::max(qvalue, static_cast<int32_t>(quant_min));
+  qvalue = std::min(qvalue, static_cast<int32_t>(quant_max));
+  return static_cast<T>(qvalue);
+}
+
+} // namespace
+
+Tensor& quantize_per_tensor_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  // Ignore context for now
+  (void)context;
+
+  // Resize output tensor to match input dimensions
+  torch::executor::Error err = resize_tensor(out, input.sizes());
+  ET_CHECK_MSG(
+      err == torch::executor::Error::Ok,
+      "Failed to resize out Tensor in quantize_per_tensor_out");
+
+  // Validate input parameters
+  check_quantize_args(input, quant_min, quant_max, dtype, out);
+
+  // Pre-compute inverse scale for better performance
+  float inv_scale = 1.0f / static_cast<float>(scale);
+  int32_t zp = static_cast<int32_t>(zero_point);
+  int32_t qmin = static_cast<int32_t>(quant_min);
+  int32_t qmax = static_cast<int32_t>(quant_max);
+
+  // Get pointers to input and output data
+  const float* input_data = input.const_data_ptr<float>();
+  int8_t* out_data = out.mutable_data_ptr<int8_t>();
+  const size_t numel = input.numel();
+
+#if defined(HAS_HELIUM_SIMD)
+  // Helium MVE implementation for float32 to int8 quantization
+  #error "Implement MVE version!"
+#else
+  // Scalar implementation for float32 to int8 quantization
+  for (size_t i = 0; i < numel; i++) {
+    out_data[i] = quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
+  }
+#endif
+
+  return out;
+}
+
+Tensor& quantize_per_tensor_out(
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+    KernelRuntimeContext context;
+    return quantize_per_tensor_out(context, input, scale, zero_point, quant_min, quant_max, dtype, out);
+}
+ 
+} // namespace native
+} // namespace cortex_m
diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml
new file mode 100644
index 00000000000..e4c28fc678a
--- /dev/null
+++ b/backends/cortex_m/ops/operators.yaml
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+- func: cortex_m::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::quantize_per_tensor_out
diff --git a/backends/cortex_m/ops/targets.bzl b/backends/cortex_m/ops/targets.bzl
new file mode 100644
index 00000000000..70c81b227c5
--- /dev/null
+++ b/backends/cortex_m/ops/targets.bzl
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
+load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
+
+def define_operator_target(name: str):
+    runtime.cxx_library(
+        name = "op_{}".format(name),
+        srcs = [
+            "op_{}.cpp".format(name),
+        ],
+        platforms = CXX,
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes"
+        ],
+        link_whole = True,
+    )
+
+OPERATORS = [
+    "quantize_per_tensor",
+]
+
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+    """
+    for op in OPERATORS:
+        define_operator_target(op)
+
+    all_op_targets = [":op_{}".format(op) for op in OPERATORS]
+
+    runtime.cxx_library(
+        name = "cortex_m_operators",
+        srcs = [],
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = all_op_targets,
+    )
+
+    export_file(name = "operators.yaml")
+
+    et_operator_library(
+        name = "ops_lib",
+        _is_external_target = True,
+        ops_schema_yaml_target = ":operators.yaml",
+    )
+
+    executorch_generated_lib(
+        name = "cortex_m_generated_lib",
+        deps = [
+            ":ops_lib",
+            ":cortex_m_operators",
+        ],
+        functions_yaml_target = ":operators.yaml",
+        platforms = CXX,
+        visibility = ["PUBLIC"],
+        define_static_targets = True,
+    )
diff --git a/backends/cortex_m/test/TARGETS b/backends/cortex_m/test/TARGETS
index 64c0358b80a..d381011b648 100644
--- a/backends/cortex_m/test/TARGETS
+++ b/backends/cortex_m/test/TARGETS
@@ -5,6 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("targets.bzl", "define_common_targets")
+
+oncall("executorch")
 
 python_unittest(
     name = "test_replace_quant_nodes",
@@ -15,4 +18,6 @@ python_unittest(
         "//executorch/backends/cortex_m/passes:replace_quant_nodes_pass",
         "//executorch/backends/cortex_m/ops:ops",
     ],
-)
+)
+
+define_common_targets()
diff --git a/backends/cortex_m/test/op_quantize_per_tensor_test.cpp b/backends/cortex_m/test/op_quantize_per_tensor_test.cpp
new file mode 100644
index 00000000000..119b10337ec
--- /dev/null
+++ b/backends/cortex_m/test/op_quantize_per_tensor_test.cpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/cortex_m/ops/NativeFunctions.h> // Declares the operator
+#include <gtest/gtest.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::KernelRuntimeContext;
+using torch::executor::testing::TensorFactory;
+
+// Test op
+using cortex_m::native::quantize_per_tensor_out;
+
+void test_dtype() {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor input = tf.full({3, 5}, 4);
+  double scale = 0.5;
+
+  int64_t zero_point = 108;
+  int64_t quant_min = 0;
+  int64_t quant_max = 127;
+
+  TensorFactory<ScalarType::Char> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  // 4 / 0.5 + 108 = 116
+  Tensor expected = tfo.full({3, 5}, 116);
+
+  KernelRuntimeContext ctx;
+  quantize_per_tensor_out(
+      ctx, input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, AllDtypesSupported) {
+    test_dtype();
+}
diff --git a/backends/cortex_m/test/targets.bzl b/backends/cortex_m/test/targets.bzl
new file mode 100644
index 00000000000..2b8cc604043
--- /dev/null
+++ b/backends/cortex_m/test/targets.bzl
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+OPERATORS = [
+    "quantize_per_tensor",
+]
+
+def define_operator_test_target(op):
+    runtime.cxx_test(
+        name = "op_{}_test".format(op),
+        srcs = [
+            "op_{}_test.cpp".format(op),
+        ],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/kernels/test:test_util",
+            "//executorch/backends/cortex_m/ops:op_{}".format(op),
+            "//executorch/backends/cortex_m/ops:cortex_m_generated_lib",
+        ]
+    )
+
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+    """
+    for op in OPERATORS:
+        define_operator_test_target(op)
+
+

From f308b356f37d485002a0dc6adbb2128f32084ae7 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 17 Apr 2025 12:47:33 -0700
Subject: [PATCH 2/2] Update on "[cortex-m] Add scalar c++ op for quantize_per_tensor"

Only the Buck build for now; CMake is next. No MVE, scalar only. Strictly
the dtypes we care about; arg_meta is updated to reflect that.

Differential Revision: [D73141767](https://our.internmc.facebook.com/intern/diff/D73141767/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D73141767/)!

[ghstack-poisoned]
---
 .../cortex_m/ops/op_quantize_per_tensor.cpp   | 17 ++++++++++-------
 .../test/op_quantize_per_tensor_test.cpp      | 15 +++++++++++----
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/backends/cortex_m/ops/op_quantize_per_tensor.cpp b/backends/cortex_m/ops/op_quantize_per_tensor.cpp
index 2303953ea91..25385602e58 100644
--- a/backends/cortex_m/ops/op_quantize_per_tensor.cpp
+++ b/backends/cortex_m/ops/op_quantize_per_tensor.cpp
@@ -82,7 +82,8 @@ T quantize_val(
     K value,
     int64_t quant_min,
     int64_t quant_max) {
-  int32_t qvalue = zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
+  int32_t qvalue =
+      zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
   qvalue = std::max(qvalue, static_cast<int32_t>(quant_min));
   qvalue = std::min(qvalue, static_cast<int32_t>(quant_max));
   return static_cast<T>(qvalue);
@@ -123,12 +124,13 @@ Tensor& quantize_per_tensor_out(
   const size_t numel = input.numel();
 
 #if defined(HAS_HELIUM_SIMD)
-  // Helium MVE implementation for float32 to int8 quantization
-  #error "Implement MVE version!"
+// Helium MVE implementation for float32 to int8 quantization
+#error "Implement MVE version!"
 #else
   // Scalar implementation for float32 to int8 quantization
   for (size_t i = 0; i < numel; i++) {
-    out_data[i] = quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
+    out_data[i] =
+        quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
   }
 #endif
 
   return out;
@@ -143,9 +145,10 @@ Tensor& quantize_per_tensor_out(
     int64_t quant_max,
     ScalarType dtype,
     Tensor& out) {
-    KernelRuntimeContext context;
-    return quantize_per_tensor_out(context, input, scale, zero_point, quant_min, quant_max, dtype, out);
+  KernelRuntimeContext context;
+  return quantize_per_tensor_out(
+      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
 }
- 
+
 } // namespace native
 } // namespace cortex_m
diff --git a/backends/cortex_m/test/op_quantize_per_tensor_test.cpp b/backends/cortex_m/test/op_quantize_per_tensor_test.cpp
index 119b10337ec..3b6e2e582a5 100644
--- a/backends/cortex_m/test/op_quantize_per_tensor_test.cpp
+++ b/backends/cortex_m/test/op_quantize_per_tensor_test.cpp
@@ -7,10 +7,10 @@
  */
 
 #include <executorch/backends/cortex_m/ops/NativeFunctions.h> // Declares the operator
-#include <gtest/gtest.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <gtest/gtest.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 
 using executorch::aten::ScalarType;
 using executorch::aten::Tensor;
@@ -38,11 +38,18 @@ void test_dtype() {
 
   KernelRuntimeContext ctx;
   quantize_per_tensor_out(
-      ctx, input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
+      ctx,
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Char,
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
 
 TEST(OpQuantizeOutTest, AllDtypesSupported) {
-    test_dtype();
+  test_dtype();
 }
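
A standalone illustration, not part of the patch: the scalar path in the new op computes `q = clamp(round(x / scale) + zero_point, quant_min, quant_max)` and stores the result as int8. The minimal sketch below (the helper name `quantize_scalar` is ours, not from the patch) mirrors that rounding/clamping logic and checks it against the value the unit test expects: round(4 / 0.5) + 108 = 116, which stays inside the [0, 127] range used by the test.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper mirroring the scalar quantize path:
// q = clamp(round(x / scale) + zero_point, qmin, qmax), returned as int8.
int8_t quantize_scalar(
    float value, float scale, int32_t zero_point, int32_t qmin, int32_t qmax) {
  const float inv_scale = 1.0f / scale; // the op precomputes this once per tensor
  int32_t q =
      zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
  q = std::min(std::max(q, qmin), qmax); // clamp into the quantized range
  return static_cast<int8_t>(q);
}

int main() {
  // Same numbers as the unit test: input 4.0, scale 0.5, zero_point 108 -> 116.
  assert(quantize_scalar(4.0f, 0.5f, 108, 0, 127) == 116);
  return 0;
}
```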