
Commit 4f7a52b

dulinriley authored and facebook-github-bot committed
Add UInt16 support to Cadence kernels
Summary: In preparation for using uint16 on Cadence, add support for it in the quant kernels. Same as pytorch#6724, but handles UInt16 as well as Bits16.

Differential Revision: D66016288
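For context, each of these kernels switches on the tensor's ScalarType and routes the data to a templated quantize/dequantize helper; this change simply makes ScalarType::UInt16 take the same uint16_t path that ScalarType::Bits16 already uses. Below is a minimal, self-contained sketch of that dispatch-plus-affine-dequantize pattern. It is illustrative only: the ScalarType enum, the toy dequantize helper, and the raw-pointer interface here are stand-ins, not the ExecuTorch/Cadence API.

#include <cstddef>
#include <cstdint>

// Illustrative stand-in for the runtime's ScalarType enum.
enum class ScalarType { Char, Short, Bits16, UInt16, Int };

// Affine dequantization: real_value = (quantized_value - zero_point) * scale.
template <typename T>
void dequantize(float* out, const T* in, float scale, int32_t zero_point, size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out[i] = (static_cast<int32_t>(in[i]) - zero_point) * scale;
  }
}

// Dispatch on the runtime dtype; Bits16 and UInt16 share the uint16_t path,
// mirroring the branch added in this commit.
void dequantize_per_tensor_sketch(
    float* out, const void* in, ScalarType dtype, float scale, int32_t zero_point, size_t numel) {
  if (dtype == ScalarType::Short) {
    dequantize<int16_t>(out, static_cast<const int16_t*>(in), scale, zero_point, numel);
  } else if (dtype == ScalarType::Bits16 || dtype == ScalarType::UInt16) {
    dequantize<uint16_t>(out, static_cast<const uint16_t*>(in), scale, zero_point, numel);
  } else if (dtype == ScalarType::Int) {
    dequantize<int32_t>(out, static_cast<const int32_t*>(in), scale, zero_point, numel);
  }
}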
1 parent b4ab76f · commit 4f7a52b

4 files changed: +4 −4 lines changed

backends/cadence/hifi/operators/dequantize_per_tensor.cpp

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ void dequantize_per_tensor_out(
   } else if (input.scalar_type() == ScalarType::Short) {
     const int16_t* input_data = input.const_data_ptr<int16_t>();
     dequantize<int16_t>(out_data, input_data, scale, zero_point, numel);
-  } else if (input.scalar_type() == ScalarType::Bits16) {
+  } else if (input.scalar_type() == ScalarType::Bits16 || input.scalar_type() == ScalarType::UInt16) {
     const uint16_t* input_data = input.const_data_ptr<uint16_t>();
     dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
   } else if (input.scalar_type() == ScalarType::Int) {

backends/cadence/hifi/operators/quantize_per_tensor.cpp

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ void quantize_per_tensor_out(
     int16_t* out_data = out.mutable_data_ptr<int16_t>();
     cadence::impl::HiFi::kernels::quantize<int16_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
-  } else if (out.scalar_type() == ScalarType::Bits16) {
+  } else if (out.scalar_type() == ScalarType::Bits16 || out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
     cadence::impl::HiFi::kernels::quantize<uint16_t>(
         out_data, input_data, 1. / scale, zero_point, numel);

backends/cadence/reference/operators/dequantize_per_tensor.cpp

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ void dequantize_per_tensor_out(
     const int8_t* input_data = input.const_data_ptr<int8_t>();
     impl::reference::kernels::dequantize<int8_t>(
         out_data, input_data, scale, zero_point, numel);
-  } else if (input.scalar_type() == ScalarType::Bits16) {
+  } else if (input.scalar_type() == ScalarType::Bits16 || input.scalar_type() == ScalarType::UInt16) {
     const uint16_t* input_data = input.const_data_ptr<uint16_t>();
     impl::reference::kernels::dequantize<uint16_t>(
         out_data, input_data, scale, zero_point, numel);

backends/cadence/reference/operators/quantize_per_tensor.cpp

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ void quantize_per_tensor_out(
     int8_t* out_data = out.mutable_data_ptr<int8_t>();
     impl::reference::kernels::quantize<int8_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
-  } else if (out.scalar_type() == ScalarType::Bits16) {
+  } else if (out.scalar_type() == ScalarType::Bits16 || out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
     impl::reference::kernels::quantize<uint16_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
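The quantize direction mirrors the dequantize dispatch: the call sites above pass 1. / scale, i.e. a precomputed reciprocal, and the standard affine formula is q = clamp(round(x * inv_scale) + zero_point) to the output type's range. Below is a hedged sketch of what such a templated helper typically does; the names and the exact rounding/clamping details are illustrative, not the actual cadence::impl::HiFi::kernels or impl::reference::kernels implementations.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>

// Affine quantization with a precomputed reciprocal scale. Values are rounded
// and clamped to T's representable range, so a uint16_t instantiation covers
// both the Bits16 and UInt16 output dtypes added in this commit.
template <typename T>
void quantize(T* out, const float* in, float inv_scale, int32_t zero_point, size_t numel) {
  constexpr int64_t kMin = std::numeric_limits<T>::min();
  constexpr int64_t kMax = std::numeric_limits<T>::max();
  for (size_t i = 0; i < numel; ++i) {
    const int64_t q = static_cast<int64_t>(std::lround(in[i] * inv_scale)) + zero_point;
    out[i] = static_cast<T>(std::min(kMax, std::max(kMin, q)));
  }
}

For a uint16_t output, zero_point would typically lie in [0, 65535] and the clamp keeps the rounded result inside that range.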
