@@ -45,6 +45,60 @@ void quantized_relu_(
4545 }
4646}
4747
48+ void quantized_relu_per_tensor_out (
49+ KernelRuntimeContext& ctx,
50+ const Tensor& input,
51+ const int64_t in_zero_point,
52+ const int64_t out_zero_point,
53+ const int64_t out_multiplier,
54+ const int64_t out_shift,
55+ Tensor& output) {
56+ const uint8_t _in_zero_point = static_cast <uint8_t >(in_zero_point);
57+ const uint8_t _out_zero_point = static_cast <uint8_t >(out_zero_point);
58+ const int32_t _out_multiplier = static_cast <int32_t >(out_multiplier);
59+ const int32_t _out_shift = static_cast <int32_t >(out_shift);
60+ if (input.scalar_type () == executorch::aten::ScalarType::Byte) {
61+ const uint8_t * p_in = input.const_data_ptr <uint8_t >();
62+ uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
63+
64+ WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u (
65+ p_out,
66+ p_in,
67+ _in_zero_point,
68+ _out_multiplier,
69+ _out_shift,
70+ _out_zero_point,
71+ _out_zero_point,
72+ 255 ,
73+ input.numel ());
74+
75+ ET_CHECK_MSG (ret_val == 0 , " An internal error occured" );
76+
77+ } else if (input.scalar_type () == executorch::aten::ScalarType::Char) {
78+ const int8_t * p_in = input.const_data_ptr <int8_t >();
79+ int8_t * p_out = output.mutable_data_ptr <int8_t >();
80+
81+ WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s (
82+ p_out,
83+ p_in,
84+ _in_zero_point,
85+ _out_multiplier,
86+ _out_shift,
87+ _out_zero_point,
88+ _out_zero_point,
89+ 127 ,
90+ input.numel ());
91+
92+ ET_CHECK_MSG (ret_val == 0 , " An internal error occured" );
93+
94+ } else {
95+ ET_CHECK_MSG (
96+ false ,
97+ " Unhandled input dtype %hhd" ,
98+ static_cast <int8_t >(input.scalar_type ()));
99+ }
100+ }
101+
48102void quantized_relu_per_tensor_out (
49103 KernelRuntimeContext& ctx,
50104 const Tensor& input,
0 commit comments