@@ -48,24 +48,27 @@ void quantized_relu_(
4848void quantized_relu_per_tensor_out (
4949 KernelRuntimeContext& ctx,
5050 const Tensor& input,
51- const Tensor& in_zero_point,
51+ const int64_t in_zero_point,
5252 const int64_t out_zero_point,
53- const Tensor& out_multiplier,
54- const Tensor& out_shift,
53+ const int64_t out_multiplier,
54+ const int64_t out_shift,
5555 Tensor& output) {
56+ const uint8_t _in_zero_point = static_cast <uint8_t >(in_zero_point);
57+ const uint8_t _out_zero_point = static_cast <uint8_t >(out_zero_point);
58+ const int32_t _out_multiplier = static_cast <int32_t >(out_multiplier);
59+ const int32_t _out_shift = static_cast <int32_t >(out_shift);
5660 if (input.scalar_type () == executorch::aten::ScalarType::Byte) {
5761 const uint8_t * p_in = input.const_data_ptr <uint8_t >();
5862 uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
59- uint8_t q_zero_point = in_zero_point.const_data_ptr <uint8_t >()[0 ];
6063
6164 WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u (
6265 p_out,
6366 p_in,
64- ( int )q_zero_point ,
65- out_multiplier. const_data_ptr < int32_t >()[ 0 ] ,
66- out_shift. const_data_ptr < int32_t >()[ 0 ] ,
67- ( int )out_zero_point ,
68- ( int )out_zero_point ,
67+ _in_zero_point ,
68+ _out_multiplier ,
69+ _out_shift ,
70+ _out_zero_point ,
71+ _out_zero_point ,
6972 255 ,
7073 input.numel ());
7174
@@ -74,16 +77,15 @@ void quantized_relu_per_tensor_out(
7477 } else if (input.scalar_type () == executorch::aten::ScalarType::Char) {
7578 const int8_t * p_in = input.const_data_ptr <int8_t >();
7679 int8_t * p_out = output.mutable_data_ptr <int8_t >();
77- int8_t q_zero_point = in_zero_point.const_data_ptr <int8_t >()[0 ];
7880
7981 WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s (
8082 p_out,
8183 p_in,
82- ( int )q_zero_point ,
83- out_multiplier. const_data_ptr < int32_t >()[ 0 ] ,
84- out_shift. const_data_ptr < int32_t >()[ 0 ] ,
85- ( int )out_zero_point ,
86- ( int )out_zero_point ,
84+ _in_zero_point ,
85+ _out_multiplier ,
86+ _out_shift ,
87+ _out_zero_point ,
88+ _out_zero_point ,
8789 127 ,
8890 input.numel ());
8991
@@ -97,6 +99,30 @@ void quantized_relu_per_tensor_out(
9799 }
98100}
99101
102+ void quantized_relu_per_tensor_out (
103+ KernelRuntimeContext& ctx,
104+ const Tensor& input,
105+ const Tensor& in_zero_point,
106+ const int64_t out_zero_point,
107+ const Tensor& out_multiplier,
108+ const Tensor& out_shift,
109+ Tensor& output) {
110+ const uint8_t * p_in = input.const_data_ptr <uint8_t >();
111+ uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
112+ uint8_t _in_zero_point = in_zero_point.const_data_ptr <uint8_t >()[0 ];
113+ int32_t _out_multiplier = out_multiplier.const_data_ptr <int32_t >()[0 ];
114+ int32_t _out_shift = out_shift.const_data_ptr <int32_t >()[0 ];
115+
116+ quantized_relu_per_tensor_out (
117+ ctx,
118+ input,
119+ _in_zero_point,
120+ out_zero_point,
121+ _out_multiplier,
122+ _out_shift,
123+ output);
124+ }
125+
100126} // namespace native
101127} // namespace HiFi
102128} // namespace impl
0 commit comments