diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 8d73d06694e..847f764b0e4 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -39,6 +39,7 @@ void check_dequantize_per_tensor_args( input.scalar_type() == ScalarType::Byte || input.scalar_type() == ScalarType::Char || input.scalar_type() == ScalarType::Bits16 || + input.scalar_type() == ScalarType::UInt16 || input.scalar_type() == ScalarType::Short || input.scalar_type() == ScalarType::Int, "input.scalar_type() %" PRId8 " is not supported:", @@ -120,6 +121,7 @@ Tensor& dequantize_per_tensor_out( switch (input.scalar_type()) { ET_FORALL_INT_TYPES(CALCULATE_INT_TYPE); CALCULATE_INT_TYPE(uint16_t, Bits16); + CALCULATE_INT_TYPE(uint16_t, UInt16); default: ET_CHECK_MSG( false, @@ -315,6 +317,7 @@ Tensor& dequantize_per_channel_out( switch (input.scalar_type()) { ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE); CALCULATE_INT_TYPE(uint16_t, Bits16); + CALCULATE_INT_TYPE(uint16_t, UInt16); default: ET_CHECK_MSG( false, diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 0b1b5c5529c..74c21e9f464 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -57,7 +57,7 @@ void check_quantize_per_tensor_args( static_cast(std::numeric_limits::min()); quant_max_upper_bound = static_cast(std::numeric_limits::max()); - } else if (dtype == ScalarType::Bits16) { + } else if (dtype == ScalarType::Bits16 || dtype == ScalarType::UInt16) { quant_min_lower_bound = std::numeric_limits::min(); quant_max_upper_bound = std::numeric_limits::max(); } else if (dtype == ScalarType::Short) { @@ -139,6 +139,7 @@ Tensor& quantize_per_tensor_out( switch (out.scalar_type()) { \ ET_FORALL_INT_TYPES_WITH(IN_CTYPE, QUANTIZE_IMPL); \ QUANTIZE_IMPL(IN_CTYPE, uint16_t, Bits16) \ + QUANTIZE_IMPL(IN_CTYPE, uint16_t, UInt16) \ default: \ ET_CHECK_MSG( \ false, \ @@ -334,6 +335,7 @@ Tensor& quantize_per_channel_out( switch (out.scalar_type()) { \ ET_FORALL_INT_TYPES_WITH(CTYPE_IN, QUANTIZE_IMPL); \ QUANTIZE_IMPL(CTYPE_IN, uint16_t, Bits16) \ + QUANTIZE_IMPL(CTYPE_IN, uint16_t, UInt16) \ default: \ ET_CHECK_MSG( \ false, \ diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index 10126264450..8d23e74e41b 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -63,6 +63,7 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) { test_dtype(); test_dtype(); test_dtype(); + test_dtype(); test_dtype(); } diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index ce81186099b..384ba630c54 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -54,6 +54,7 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype(); test_dtype(); test_dtype(); + test_dtype(); test_dtype(); } diff --git a/runtime/core/exec_aten/testing_util/tensor_factory.h b/runtime/core/exec_aten/testing_util/tensor_factory.h index 9f8f7e9db75..9ccda151283 100644 --- a/runtime/core/exec_aten/testing_util/tensor_factory.h +++ b/runtime/core/exec_aten/testing_util/tensor_factory.h @@ -650,6 +650,13 @@ struct ScalarTypeToCppTypeWrapper { using ctype = uint16_t; }; +// Use a C type of `uint16_t` instead of `UInt16` to simplify code reuse when +// testing multiple integer types. +template <> +struct ScalarTypeToCppTypeWrapper { + using ctype = uint16_t; +}; + // To allow implicit conversion between simple types to `ctype` #define SPECIALIZE_ScalarTypeToCppTypeWrapper(CTYPE, DTYPE) \ template <> \