@@ -20,28 +20,31 @@ using executorch::aten::Tensor;
2020using executorch::runtime::KernelRuntimeContext;
2121
2222void dequantize_per_tensor_out (
23- KernelRuntimeContext& context ,
23+ KernelRuntimeContext& ctx ,
2424 const Tensor& input,
2525 double scale,
2626 int64_t zero_point,
27- int64_t quant_min,
28- int64_t quant_max,
27+ __ET_UNUSED int64_t quant_min,
28+ __ET_UNUSED int64_t quant_max,
2929 ScalarType dtype,
3030 Tensor& out) {
3131 float * out_data = out.mutable_data_ptr <float >();
32- size_t numel = out.numel ();
33-
32+ const size_t numel = out.numel ();
3433 if (input.scalar_type () == ScalarType::Byte) {
3534 const uint8_t * input_data = input.const_data_ptr <uint8_t >();
36- impl::HiFi:: kernels::dequantize<uint8_t >(
35+ kernels::dequantize<uint8_t >(
3736 out_data, input_data, scale, zero_point, numel);
3837 } else if (input.scalar_type () == ScalarType::Char) {
3938 const int8_t * input_data = input.const_data_ptr <int8_t >();
4039 xa_nn_elm_dequantize_asym8s_f32 (
4140 out_data, input_data, zero_point, scale, numel);
41+ } else if (input.scalar_type () == ScalarType::Short) {
42+ const int16_t * input_data = input.const_data_ptr <int16_t >();
43+ kernels::dequantize<int16_t >(
44+ out_data, input_data, scale, zero_point, numel);
4245 } else if (input.scalar_type () == ScalarType::Int) {
4346 const int32_t * input_data = input.const_data_ptr <int32_t >();
44- impl::HiFi:: kernels::dequantize<int32_t >(
47+ kernels::dequantize<int32_t >(
4548 out_data, input_data, scale, zero_point, numel);
4649 } else {
4750 ET_CHECK_MSG (false , " Unhandled input dtype %hhd" , input.scalar_type ());
0 commit comments