diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index 10e5fb176e0..1b335c846bb 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -165,6 +165,7 @@ void requantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); #undef typed_quantize_val #define typed_quantize_vec(dtype) \ @@ -177,6 +178,7 @@ typed_quantize_val(int16_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -186,6 +188,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); #undef typed_dequantize_val #define typed_dequantize_vec(dtype) \ @@ -198,6 +201,7 @@ typed_dequantize_val(int16_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp index 18381a26e0a..996d753c59d 100644 --- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp @@ -41,6 +41,9 @@ void dequantize_per_tensor_out( } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp index c65d62968f5..1078b5716c1 100644 --- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp @@ -44,6 +44,10 @@ void quantize_per_tensor_out( int16_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + cadence::impl::HiFi::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( diff --git a/backends/cadence/reference/kernels/kernels.cpp b/backends/cadence/reference/kernels/kernels.cpp index 4d4ff26c3fb..faac3d7cb27 100644 --- a/backends/cadence/reference/kernels/kernels.cpp +++ b/backends/cadence/reference/kernels/kernels.cpp @@ -65,6 +65,7 @@ void dequantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); typed_quantize_val(int32_t); #undef typed_quantize_val @@ -78,6 +79,7 @@ typed_quantize_val(int32_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -86,6 +88,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); typed_dequantize_val(int32_t); #undef typed_dequantize_val @@ -99,6 +102,7 @@ typed_dequantize_val(int32_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp index aef730bfd1b..b49c045b94f 100644 --- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp @@ -37,6 +37,14 @@ void dequantize_per_tensor_out( const int8_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp index 0d7ff0bc7ea..ad5fa791b51 100644 --- a/backends/cadence/reference/operators/quantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp @@ -39,6 +39,14 @@ void quantize_per_tensor_out( int8_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize(