From 1f4905a920de3c0ee0c3103d9a9808871a6e6153 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Wed, 12 Feb 2025 11:16:12 -0800 Subject: [PATCH] Fix HiFi relu for int8 Summary: As titled. Differential Revision: D69541729 --- .../cadence/hifi/operators/op_quantized_relu_out.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp index 28227b7cc92..a7c99378920 100644 --- a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp @@ -26,11 +26,11 @@ void quantized_relu_per_tensor_out( const int64_t out_multiplier, const int64_t out_shift, Tensor& output) { - const uint8_t _in_zero_point = static_cast(in_zero_point); - const uint8_t _out_zero_point = static_cast(out_zero_point); - const int32_t _out_multiplier = static_cast(out_multiplier); - const int32_t _out_shift = static_cast(out_shift); if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + const uint8_t _in_zero_point = static_cast(in_zero_point); + const uint8_t _out_zero_point = static_cast(out_zero_point); + const int32_t _out_multiplier = static_cast(out_multiplier); + const int32_t _out_shift = static_cast(out_shift); const uint8_t* p_in = input.const_data_ptr(); uint8_t* p_out = output.mutable_data_ptr(); @@ -48,6 +48,10 @@ void quantized_relu_per_tensor_out( ET_CHECK_MSG(ret_val == 0, "An internal error occured"); } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + const int8_t _in_zero_point = static_cast(in_zero_point); + const int8_t _out_zero_point = static_cast(out_zero_point); + const int32_t _out_multiplier = static_cast(out_multiplier); + const int32_t _out_shift = static_cast(out_shift); const int8_t* p_in = input.const_data_ptr(); int8_t* p_out = output.mutable_data_ptr();