diff --git a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp index cd596417916..df09c7a12dd 100644 --- a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp @@ -20,11 +20,26 @@ namespace cadence { namespace impl { namespace HiFi { namespace native { - +namespace { using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; +// Add checks for dtype quant min/max bounds. +template +void check_quant_min_and_max( + KernelRuntimeContext& ctx, + const int64_t quant_min, + const int64_t quant_max) { + ET_KERNEL_CHECK( + ctx, + std::numeric_limits::min() == quant_min && + std::numeric_limits::max() == quant_max, + InvalidArgument, ); +} + +} // namespace + // Quantize the input tensor (PT2 version). Note that quant_ are not // used in any computation. void quantize_per_tensor_out( @@ -36,15 +51,43 @@ void quantize_per_tensor_out( __ET_UNUSED int64_t quant_max, const ScalarType dtype, Tensor& out) { - // Add checks for dtype quant min/max bounds. - ET_SWITCH_REALB_TYPES( - out.scalar_type(), ctx, "quantize_per_tensor", OUT_DTYPE, [&]() { - ET_KERNEL_CHECK( - ctx, - std::numeric_limits::min() == quant_min && - std::numeric_limits::max() == quant_max, - InvalidArgument, ); - }); + // Check for input scalar type. + ET_KERNEL_CHECK_MSG( + ctx, + input.scalar_type() == ScalarType::Float, + InvalidType, + , + "Input tensor for quantize_per_tensor.out should be type %s, but got %s", + ::torch::executor::toString(ScalarType::Float), + ::torch::executor::toString(input.scalar_type())); + + // Check quant min/max for output types. + switch (out.scalar_type()) { + case ScalarType::Byte: + check_quant_min_and_max(ctx, quant_min, quant_max); + break; + case ScalarType::Char: + check_quant_min_and_max(ctx, quant_min, quant_max); + break; + case ScalarType::Short: + check_quant_min_and_max(ctx, quant_min, quant_max); + break; + case ScalarType::Bits16: + case ScalarType::UInt16: + check_quant_min_and_max(ctx, quant_min, quant_max); + break; + case ScalarType::Int: + check_quant_min_and_max(ctx, quant_min, quant_max); + break; + default: + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidType, + , + "Unhandled output dtype %s", + ::torch::executor::toString(out.scalar_type())); + } const float* input_data = input.const_data_ptr(); const size_t numel = out.numel(); diff --git a/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp b/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp index b84c81d1d2d..c8d5b03ce75 100644 --- a/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp @@ -106,7 +106,7 @@ TEST_F(HiFiQuantizePerTensorTest, ThrowKernelFailureForQuantMaxLessThanLimit) { out)); } -TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementQuantize) { +TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementIntQuantize) { TensorFactory tf; const std::vector sizes{1}; constexpr ScalarType kOutDtype = ScalarType::Int; @@ -132,6 +132,32 @@ TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementQuantize) { EXPECT_TENSOR_EQ(out, tf_out.make(sizes, {kExpectedOutputValue})); } +TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementUInt16Quantize) { + TensorFactory tf; + const std::vector sizes{1}; + constexpr ScalarType kOutDtype = ScalarType::UInt16; + TensorFactory tf_out; + Tensor out = tf_out.zeros(sizes); + // Some arbitrary values for scalar args. + constexpr double kScale = 0.01; + constexpr int64_t kZeroPoint = 32768; + constexpr int64_t kQuantMin = std::numeric_limits::min(); + constexpr int64_t kQuantMax = std::numeric_limits::max(); + constexpr float kInputValue = 100.0f; + constexpr uint16_t kExpectedOutputValue = + static_cast(kInputValue / kScale + kZeroPoint); + + quantize_per_tensor_out( + tf.make(sizes, {kInputValue}), + kScale, + kZeroPoint, + kQuantMin, + kQuantMax, + kOutDtype, + out); + EXPECT_TENSOR_EQ(out, tf_out.make(sizes, {kExpectedOutputValue})); +} + } // namespace } // namespace native } // namespace HiFi