From 73cf11112e059b148b3aa94fd94eacc23c2841f1 Mon Sep 17 00:00:00 2001 From: Rahul Chandra Date: Thu, 20 Nov 2025 13:55:48 -0800 Subject: [PATCH 1/2] 2/n Enable 16-bit activations and 8-bit weights in Cadence Quantizer for linear (#15901) Summary: # Context We continue from D84284794 to add support for 16-bit activations. Note that right now, all though they support 16-bit activations already, it's only if the weights are also 16-bits. To do this, we need to change the way we template some functions. # Current Behavior Right now, we're composing two macros together, the `ET_FORALL_JARVIS_QUANTIZED_TYPES_WITH_INT16` macro: https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h?lines=22-25 and the function macro(`quantized_linear` chosen for example): https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/quantized_linear_out.cpp?lines=30-41 so together, it just becomes a switch statement, calling the `quantized_linear` function with the correct template parameter. However, note that it assumes that both the input activations and weights are the same dtype, which is not the case. # This Diff We fix the generic implementation by allowing there to be two generics, one for the weight and one for the input activations. Reviewed By: hsharma35 Differential Revision: D86538176 --- .../operators/op_quantized_linear_out.cpp | 43 +++++- backends/cadence/hifi/operators/targets.bzl | 5 +- .../tests/test_op_quantized_linear_out.cpp | 132 ++++++++++++++++++ 3 files changed, 175 insertions(+), 5 deletions(-) create mode 100644 backends/cadence/hifi/operators/tests/test_op_quantized_linear_out.cpp diff --git a/backends/cadence/hifi/operators/op_quantized_linear_out.cpp b/backends/cadence/hifi/operators/op_quantized_linear_out.cpp index 84aff1c2f41..d9f4e41bc39 100644 --- a/backends/cadence/hifi/operators/op_quantized_linear_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_linear_out.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -207,7 +208,7 @@ void inline _quantized_linear_per_tensor_asym8s( } void quantized_linear_out( - __ET_UNUSED KernelRuntimeContext& ctx, + KernelRuntimeContext& ctx, const Tensor& in, const Tensor& weight, const Tensor& bias, @@ -216,9 +217,26 @@ void quantized_linear_out( const Tensor& out_multiplier, const Tensor& out_shift, int64_t out_zero_point, - __ET_UNUSED const optional& offset, + const optional& offset, Tensor& out) { - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + in.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_linear_out( + ctx, + in, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + out); + } + + else if (out.scalar_type() == executorch::aten::ScalarType::Byte) { _quantized_linear_asym8u( in, weight, @@ -260,7 +278,24 @@ void quantized_linear_per_tensor_out( int64_t out_zero_point, __ET_UNUSED const optional& offset, Tensor& out) { - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + in.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_linear_per_tensor_out( + ctx, + in, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + out); + } + + else if (out.scalar_type() == executorch::aten::ScalarType::Byte) { _quantized_linear_per_tensor_asym8u( in, weight, diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index a25dfd1bcbc..5d135e320bf 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -87,7 +87,6 @@ OPERATORS = [ "quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out", "quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out", "quantized_layer_norm", - "quantized_linear_out", "quantized_linear_asym8sxasym8s_asym8s_per_tensor_out", "quantized_linear_asym8uxasym8u_asym8u_per_tensor_out", "quantized_matmul_out", @@ -122,3 +121,7 @@ def define_common_targets(): # Define build targets for all operators registered in the tables above. for op in OPERATORS: define_operator(op) + + # quantized_linear_out and quantized_linear_per_tensor_out needs additional dependency for int16 support + define_operator("quantized_linear_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:quantize_linear_out", "fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:headers",]) + define_operator("quantized_linear_per_tensor_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:quantize_linear_out", "fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:headers",]) diff --git a/backends/cadence/hifi/operators/tests/test_op_quantized_linear_out.cpp b/backends/cadence/hifi/operators/tests/test_op_quantized_linear_out.cpp new file mode 100644 index 00000000000..fddf373290f --- /dev/null +++ b/backends/cadence/hifi/operators/tests/test_op_quantized_linear_out.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace impl { +namespace HiFi { +namespace native { +namespace { + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::runtime_init; +using ::executorch::runtime::testing::TensorFactory; +using std::optional; +using std::string_view; + +class HiFiQuantizedLinearTest : public OperatorTest { + public: + protected: + void quantized_linear_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + const optional& offset, + Tensor& output) { + return ::impl::HiFi::native::quantized_linear_out( + context_, + input, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + output); + } + + void quantized_linear_per_tensor_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const optional& offset, + Tensor& output) { + return ::impl::HiFi::native::quantized_linear_per_tensor_out( + context_, + input, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + output); + } +}; + +// Test quantized_linear_out with int16 activations (asym8s) +TEST_F(HiFiQuantizedLinearTest, QuantizedLinearInt16Test) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + + // Simple 2D case: input [2, 3] x weight [4, 3] = output [2, 4] + // Values captured from e2e test with + // CadenceWith16BitLinearActivationsQuantizer + Tensor input = + tf_int16.make({2, 3}, {-28170, -26389, -32768, -31474, -32266, -29076}); + Tensor weight = tf_int8.make( + {4, 3}, {1, 87, -128, -114, -59, 44, -1, 127, -12, 44, -46, -29}); + Tensor bias = tf_int32.zeros({4}); + Tensor output = tf_int16.zeros({2, 4}); + + int64_t in_zero_point = -29822; + Tensor weight_zero_point = tf_int32.make({1}, {2}); + Tensor out_multiplier = tf_int32.make({1}, {2011373824}); + Tensor out_shift = tf_int32.make({1}, {-8}); + int64_t out_zero_point = -30847; + quantized_linear_out( + input, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + std::nullopt, + output); + // Expected output from e2e test + Tensor expected_output = tf_int16.make( + {2, 4}, {-28384, -32767, -29144, -30862, -31956, -29486, -31985, -30756}); + EXPECT_TENSOR_CLOSE(output, expected_output); +} + +} // namespace +} // namespace native +} // namespace HiFi +} // namespace impl From e064f8e94734c0f552f5ba1dab1856be1d9112c1 Mon Sep 17 00:00:00 2001 From: Rahul Chandra Date: Thu, 20 Nov 2025 13:55:48 -0800 Subject: [PATCH 2/2] Enable 16-bit activations and 8-bit weights in Cadence Quantizer for Conv Summary: # Context We continue from D84284794 to add support for 16-bit activations. Note that right now, all though they support 16-bit activations already, it's only if the weights are also 16-bits. To do this, we need to change the way we template some functions. # Current Behavior Right now, we're composing two macros together, the `ET_FORALL_JARVIS_QUANTIZED_TYPES_WITH_INT16` macro: https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h?lines=22-25 and the function macro(`quantized_linear` chosen for example): https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/quantized_linear_out.cpp?lines=30-41 so together, it just becomes a switch statement, calling the `quantized_linear` function with the correct template parameter. However, note that it assumes that both the input activations and weights are the same dtype, which is not the case. # This Diff We fix this checking for our datatypes, and calling the functions with the correct data types, as in D86538176. Reviewed By: hsharma35 Differential Revision: D86643471 --- backends/cadence/aot/quantizer/quantizer.py | 14 ++ .../op_quantized_conv2d_nchw_out.cpp | 49 ++++ .../op_quantized_conv2d_nhwc_out.cpp | 49 ++++ backends/cadence/hifi/operators/targets.bzl | 6 +- .../tests/test_op_quantized_conv2d_out.cpp | 222 ++++++++++++++++++ backends/cadence/vision/kernels/kernels.cpp | 16 +- 6 files changed, 352 insertions(+), 4 deletions(-) create mode 100644 backends/cadence/hifi/operators/tests/test_op_quantized_conv2d_out.cpp diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 70b16b86fda..7dac4049feb 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -372,3 +372,17 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: # Add 16-bit quantizers for LinearPattern quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16)) super().__init__(quantizers) + + +class CadenceWith16BitConvActivationsQuantizer(CadenceQuantizer): + """ + Quantizer including A16 conv + """ + + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = [] + # Add 16-bit quantizers for Conv patterns + quantizers.append(CadenceAtenQuantizer(Conv1dPattern(), qconfig_A16)) + quantizers.append(CadenceAtenQuantizer(Conv2dPattern(), qconfig_A16)) + super().__init__(quantizers) diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp index 984747d9316..79660beee4d 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) @@ -532,6 +533,30 @@ void quantized_conv2d_nchw_out( __ET_UNUSED const Tensor& out_multiplier, __ET_UNUSED const Tensor& out_shift, Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nchw_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + const float bias_scale_float = bias_scale.const_data_ptr()[0]; const int32_t weight_zero_point_int = weight_zero_point.const_data_ptr()[0]; @@ -596,6 +621,30 @@ void quantized_conv2d_nchw_per_tensor_out( __ET_UNUSED int64_t out_multiplier, __ET_UNUSED int64_t out_shift, Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nchw_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + bool optimized = 0; if ((input.scalar_type() == ScalarType::Char) || diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp index a5d503853c4..b3e4cda036a 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) @@ -438,6 +439,30 @@ void quantized_conv2d_nhwc_out( __ET_UNUSED const Tensor& out_multiplier, __ET_UNUSED const Tensor& out_shift, Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nhwc_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + const float bias_scale_float = bias_scale.const_data_ptr()[0]; const int32_t weight_zero_point_int = weight_zero_point.const_data_ptr()[0]; @@ -502,6 +527,30 @@ void quantized_conv2d_nhwc_per_tensor_out( __ET_UNUSED int64_t out_multiplier, __ET_UNUSED int64_t out_shift, Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nhwc_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + bool optimized = 0; if ((input.scalar_type() == ScalarType::Char) || diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index 5d135e320bf..b6478fe08d0 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -65,7 +65,6 @@ OPERATORS = [ "ne", "permute_copy", "pow", - "quantized_conv2d_nchw_out", "quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out", @@ -74,7 +73,6 @@ OPERATORS = [ "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out", - "quantized_conv2d_nhwc_out", "quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out", @@ -125,3 +123,7 @@ def define_common_targets(): # quantized_linear_out and quantized_linear_per_tensor_out needs additional dependency for int16 support define_operator("quantized_linear_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:quantize_linear_out", "fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:headers",]) define_operator("quantized_linear_per_tensor_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:quantize_linear_out", "fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:headers",]) + + # quantized_conv2d_nchw_out and quantized_conv2d_nhwc_out need additional dependency for int16 support + define_operator("quantized_conv2d_nchw_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:quantize_conv2d_out", "fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:headers",]) + define_operator("quantized_conv2d_nhwc_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:quantize_conv2d_out", "fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:headers",]) diff --git a/backends/cadence/hifi/operators/tests/test_op_quantized_conv2d_out.cpp b/backends/cadence/hifi/operators/tests/test_op_quantized_conv2d_out.cpp new file mode 100644 index 00000000000..54e1432aa2a --- /dev/null +++ b/backends/cadence/hifi/operators/tests/test_op_quantized_conv2d_out.cpp @@ -0,0 +1,222 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace impl { +namespace HiFi { +namespace native { +namespace { + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::runtime_init; +using ::executorch::runtime::testing::TensorFactory; +using std::optional; +using std::string_view; + +class HiFiQuantizedConv2dTest : public OperatorTest { + public: + protected: + void quantized_conv2d_nchw_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + return ::impl::HiFi::native::quantized_conv2d_nchw_out( + context_, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + } + + void quantized_conv2d_nhwc_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + return ::impl::HiFi::native::quantized_conv2d_nhwc_out( + context_, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + } +}; + +// Test quantized_conv2d_nchw_out with int16 activations and int8 weights +TEST_F(HiFiQuantizedConv2dTest, QuantizedConv2dNchwInt16Test) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + TensorFactory tf_float; + + // Simple 2D case: input [1, 8, 20, 28] with kernel [16, 8, 3, 5] + // Using simple values for testing + Tensor input = tf_int16.ones({1, 8, 20, 28}); + Tensor weight = tf_int8.ones({16, 8, 3, 5}); + Tensor bias = tf_int32.zeros({16}); + + // Calculate output dimensions: (20-3)/1+1=18, (28-5)/1+1=24 + Tensor output = tf_int16.zeros({1, 16, 18, 24}); + + int64_t in_zero_point = 0; + Tensor weight_zero_point = tf_int32.make({1}, {0}); + Tensor bias_scale = tf_float.make({1}, {1.0f}); + double output_scale = 1.0; + int64_t output_zero_point = 0; + Tensor out_multiplier = tf_int32.make({1}, {1073741824}); // 0.5 * 2^31 + Tensor out_shift = tf_int32.make({1}, {0}); + + int64_t stride_arr[] = {1, 1}; + int64_t padding_arr[] = {0, 0}; + int64_t dilation_arr[] = {1, 1}; + + ::executorch::aten::ArrayRef stride(stride_arr, 2); + ::executorch::aten::ArrayRef padding(padding_arr, 2); + ::executorch::aten::ArrayRef dilation(dilation_arr, 2); + + quantized_conv2d_nchw_out( + input, + weight, + bias, + stride, + padding, + dilation, + 1, // groups + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + + // Basic sanity check - output should be non-zero + // With all ones input and weights, and kernel size 3x5=15 * 8 channels = 120 + // Expected value per output element would be around 120 + EXPECT_NE(output.const_data_ptr()[0], 0); +} + +// Test quantized_conv2d_nhwc_out with int16 activations and int8 weights +TEST_F(HiFiQuantizedConv2dTest, QuantizedConv2dNhwcInt16Test) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + TensorFactory tf_float; + + // Simple 2D case in NHWC format: input [1, 20, 28, 8] with kernel [16, 3, 5, + // 8] + Tensor input = tf_int16.ones({1, 20, 28, 8}); + Tensor weight = tf_int8.ones({16, 3, 5, 8}); + Tensor bias = tf_int32.zeros({16}); + + // Calculate output dimensions: (20-3)/1+1=18, (28-5)/1+1=24 + Tensor output = tf_int16.zeros({1, 18, 24, 16}); + + int64_t in_zero_point = 0; + Tensor weight_zero_point = tf_int32.make({1}, {0}); + Tensor bias_scale = tf_float.make({1}, {1.0f}); + double output_scale = 1.0; + int64_t output_zero_point = 0; + Tensor out_multiplier = tf_int32.make({1}, {1073741824}); // 0.5 * 2^31 + Tensor out_shift = tf_int32.make({1}, {0}); + + int64_t stride_arr[] = {1, 1}; + int64_t padding_arr[] = {0, 0}; + int64_t dilation_arr[] = {1, 1}; + + ::executorch::aten::ArrayRef stride(stride_arr, 2); + ::executorch::aten::ArrayRef padding(padding_arr, 2); + ::executorch::aten::ArrayRef dilation(dilation_arr, 2); + + quantized_conv2d_nhwc_out( + input, + weight, + bias, + stride, + padding, + dilation, + 1, // groups + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + + // Basic sanity check - output should be non-zero + EXPECT_NE(output.const_data_ptr()[0], 0); +} + +} // namespace +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/vision/kernels/kernels.cpp b/backends/cadence/vision/kernels/kernels.cpp index 70c811df741..04c239df390 100644 --- a/backends/cadence/vision/kernels/kernels.cpp +++ b/backends/cadence/vision/kernels/kernels.cpp @@ -18,8 +18,20 @@ namespace vision { namespace kernels { void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) { - Result temp_mem_res = ctx.allocate_temp(size); - return temp_mem_res.ok() ? temp_mem_res.get() : nullptr; + ET_LOG(Info, "yo"); + constexpr size_t kAlignment = + 8; // 16-byte alignment for vectorized operations + Result temp_mem_res = ctx.allocate_temp(size, kAlignment); + if (temp_mem_res.ok()) { + void* ptr = temp_mem_res.get(); + return ptr; + } else { + ET_LOG( + Error, + "Failed to allocate temp memory, error: 0x%x", + static_cast(temp_mem_res.error())); + return nullptr; + } } // Quantize a fp32 value to an int8_t/uint8_t value