From c13cfaec0b7aa72877f4faa12822534e6e57364e Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Thu, 7 Nov 2024 15:57:34 -0800 Subject: [PATCH] Add support for uint16 in quant and dequant kernels Summary: In preparation for using uint16 on Cadence, add support to the quant kernels. We use the "Bits16" type for this purpose. This is an uninterpreted dtype that only works for a subset of kernels. We plan to only support it as a quant and dequant target. Reviewed By: hsharma35, mcremon-meta Differential Revision: D65370235 --- backends/cadence/hifi/kernels/kernels.cpp | 4 ++++ backends/cadence/hifi/operators/dequantize_per_tensor.cpp | 3 +++ backends/cadence/hifi/operators/quantize_per_tensor.cpp | 4 ++++ backends/cadence/reference/kernels/kernels.cpp | 4 ++++ .../cadence/reference/operators/dequantize_per_tensor.cpp | 8 ++++++++ .../cadence/reference/operators/quantize_per_tensor.cpp | 8 ++++++++ 6 files changed, 31 insertions(+) diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index 10e5fb176e0..1b335c846bb 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -165,6 +165,7 @@ void requantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); #undef typed_quantize_val #define typed_quantize_vec(dtype) \ @@ -177,6 +178,7 @@ typed_quantize_val(int16_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -186,6 +188,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); #undef typed_dequantize_val #define typed_dequantize_vec(dtype) \ @@ -198,6 +201,7 @@ typed_dequantize_val(int16_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp index 18381a26e0a..996d753c59d 100644 --- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp @@ -41,6 +41,9 @@ void dequantize_per_tensor_out( } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp index c65d62968f5..1078b5716c1 100644 --- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp @@ -44,6 +44,10 @@ void quantize_per_tensor_out( int16_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + cadence::impl::HiFi::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( diff --git a/backends/cadence/reference/kernels/kernels.cpp b/backends/cadence/reference/kernels/kernels.cpp index 4d4ff26c3fb..faac3d7cb27 100644 --- a/backends/cadence/reference/kernels/kernels.cpp +++ b/backends/cadence/reference/kernels/kernels.cpp @@ -65,6 +65,7 @@ void dequantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); typed_quantize_val(int32_t); #undef typed_quantize_val @@ -78,6 +79,7 @@ typed_quantize_val(int32_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -86,6 +88,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); typed_dequantize_val(int32_t); #undef typed_dequantize_val @@ -99,6 +102,7 @@ typed_dequantize_val(int32_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp index aef730bfd1b..b49c045b94f 100644 --- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp @@ -37,6 +37,14 @@ void dequantize_per_tensor_out( const int8_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp index 0d7ff0bc7ea..ad5fa791b51 100644 --- a/backends/cadence/reference/operators/quantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp @@ -39,6 +39,14 @@ void quantize_per_tensor_out( int8_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize(