@@ -15,33 +15,37 @@ namespace impl {
1515namespace HiFi {
1616namespace native {
1717
18- using executorch::aten::ScalarType;
19- using executorch::aten::Tensor;
20- using executorch::runtime::KernelRuntimeContext;
18+ using ::executorch::aten::ScalarType;
19+ using ::executorch::aten::Tensor;
20+ using ::executorch::runtime::KernelRuntimeContext;
21+ using ::cadence::impl::HiFi::kernels::dequantize;
2122
2223void dequantize_per_tensor_out (
23- KernelRuntimeContext& context ,
24+ KernelRuntimeContext& ctx ,
2425 const Tensor& input,
2526 double scale,
2627 int64_t zero_point,
27- int64_t quant_min,
28- int64_t quant_max,
28+ __ET_UNUSED int64_t quant_min,
29+ __ET_UNUSED int64_t quant_max,
2930 ScalarType dtype,
3031 Tensor& out) {
3132 float * out_data = out.mutable_data_ptr <float >();
32- size_t numel = out.numel ();
33-
33+ const size_t numel = out.numel ();
3434 if (input.scalar_type () == ScalarType::Byte) {
3535 const uint8_t * input_data = input.const_data_ptr <uint8_t >();
36- impl::HiFi::kernels:: dequantize<uint8_t >(
36+ dequantize<uint8_t >(
3737 out_data, input_data, scale, zero_point, numel);
3838 } else if (input.scalar_type () == ScalarType::Char) {
3939 const int8_t * input_data = input.const_data_ptr <int8_t >();
4040 xa_nn_elm_dequantize_asym8s_f32 (
4141 out_data, input_data, zero_point, scale, numel);
42+ } else if (input.scalar_type () == ScalarType::Short) {
43+ const int16_t * input_data = input.const_data_ptr <int16_t >();
44+ dequantize<int16_t >(
45+ out_data, input_data, scale, zero_point, numel);
4246 } else if (input.scalar_type () == ScalarType::Int) {
4347 const int32_t * input_data = input.const_data_ptr <int32_t >();
44- impl::HiFi::kernels:: dequantize<int32_t >(
48+ dequantize<int32_t >(
4549 out_data, input_data, scale, zero_point, numel);
4650 } else {
4751 ET_CHECK_MSG (false , " Unhandled input dtype %hhd" , input.scalar_type ());
0 commit comments