@@ -15,34 +15,35 @@ namespace impl {
1515namespace HiFi {
1616namespace native {
1717
18- using executorch::aten::ScalarType;
19- using executorch::aten::Tensor;
20- using executorch::runtime::KernelRuntimeContext;
18+ using ::cadence::impl::HiFi::kernels::dequantize;
19+ using ::executorch::aten::ScalarType;
20+ using ::executorch::aten::Tensor;
21+ using ::executorch::runtime::KernelRuntimeContext;
2122
2223void dequantize_per_tensor_out (
23- KernelRuntimeContext& context ,
24+ KernelRuntimeContext& ctx ,
2425 const Tensor& input,
2526 double scale,
2627 int64_t zero_point,
27- int64_t quant_min,
28- int64_t quant_max,
28+ __ET_UNUSED int64_t quant_min,
29+ __ET_UNUSED int64_t quant_max,
2930 ScalarType dtype,
3031 Tensor& out) {
3132 float * out_data = out.mutable_data_ptr <float >();
32- size_t numel = out.numel ();
33-
33+ const size_t numel = out.numel ();
3434 if (input.scalar_type () == ScalarType::Byte) {
3535 const uint8_t * input_data = input.const_data_ptr <uint8_t >();
36- impl::HiFi::kernels::dequantize<uint8_t >(
37- out_data, input_data, scale, zero_point, numel);
36+ dequantize<uint8_t >(out_data, input_data, scale, zero_point, numel);
3837 } else if (input.scalar_type () == ScalarType::Char) {
3938 const int8_t * input_data = input.const_data_ptr <int8_t >();
4039 xa_nn_elm_dequantize_asym8s_f32 (
4140 out_data, input_data, zero_point, scale, numel);
41+ } else if (input.scalar_type () == ScalarType::Short) {
42+ const int16_t * input_data = input.const_data_ptr <int16_t >();
43+ dequantize<int16_t >(out_data, input_data, scale, zero_point, numel);
4244 } else if (input.scalar_type () == ScalarType::Int) {
4345 const int32_t * input_data = input.const_data_ptr <int32_t >();
44- impl::HiFi::kernels::dequantize<int32_t >(
45- out_data, input_data, scale, zero_point, numel);
46+ dequantize<int32_t >(out_data, input_data, scale, zero_point, numel);
4647 } else {
4748 ET_CHECK_MSG (false , " Unhandled input dtype %hhd" , input.scalar_type ());
4849 }
0 commit comments