Skip to content

Commit 2519d06

Browse files
zonglinpengfacebook-github-bot
authored andcommitted
dequant_per_tensor (#6439)
Summary: ~ Reviewed By: hsharma35 Differential Revision: D64774100
1 parent fa30e80 commit 2519d06

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

backends/cadence/hifi/operators/dequantize_per_tensor.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,37 @@ namespace impl {
1515
namespace HiFi {
1616
namespace native {
1717

18-
using executorch::aten::ScalarType;
19-
using executorch::aten::Tensor;
20-
using executorch::runtime::KernelRuntimeContext;
18+
using ::executorch::aten::ScalarType;
19+
using ::executorch::aten::Tensor;
20+
using ::executorch::runtime::KernelRuntimeContext;
21+
using ::cadence::impl::HiFi::kernels::dequantize;
2122

2223
void dequantize_per_tensor_out(
23-
KernelRuntimeContext& context,
24+
KernelRuntimeContext& ctx,
2425
const Tensor& input,
2526
double scale,
2627
int64_t zero_point,
27-
int64_t quant_min,
28-
int64_t quant_max,
28+
__ET_UNUSED int64_t quant_min,
29+
__ET_UNUSED int64_t quant_max,
2930
ScalarType dtype,
3031
Tensor& out) {
3132
float* out_data = out.mutable_data_ptr<float>();
32-
size_t numel = out.numel();
33-
33+
const size_t numel = out.numel();
3434
if (input.scalar_type() == ScalarType::Byte) {
3535
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
36-
impl::HiFi::kernels::dequantize<uint8_t>(
36+
dequantize<uint8_t>(
3737
out_data, input_data, scale, zero_point, numel);
3838
} else if (input.scalar_type() == ScalarType::Char) {
3939
const int8_t* input_data = input.const_data_ptr<int8_t>();
4040
xa_nn_elm_dequantize_asym8s_f32(
4141
out_data, input_data, zero_point, scale, numel);
42+
} else if (input.scalar_type() == ScalarType::Short) {
43+
const int16_t* input_data = input.const_data_ptr<int16_t>();
44+
dequantize<int16_t>(
45+
out_data, input_data, scale, zero_point, numel);
4246
} else if (input.scalar_type() == ScalarType::Int) {
4347
const int32_t* input_data = input.const_data_ptr<int32_t>();
44-
impl::HiFi::kernels::dequantize<int32_t>(
48+
dequantize<int32_t>(
4549
out_data, input_data, scale, zero_point, numel);
4650
} else {
4751
ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());

0 commit comments

Comments
 (0)