66 * LICENSE file in the root directory of this source tree.
77 */
88
9- #include < ATen/native/cpu/Elu.h >
9+ #include < cmath >
1010
11+ #include < ATen/cpu/vec/functional.h>
12+ #include < ATen/cpu/vec/vec.h>
1113#include < executorch/kernels/portable/cpu/scalar_utils.h>
1214#include < executorch/runtime/kernel/kernel_includes.h>
13- #include < executorch/runtime/kernel/thread_parallel_interface.h>
1415#include < executorch/runtime/platform/assert.h>
1516
1617namespace torch ::executor::native {
@@ -31,38 +32,24 @@ void elu(
3132 const auto math_alpha = utils::scalar_to<MathT>(alpha);
3233 const auto math_scale = utils::scalar_to<MathT>(scale);
3334 const auto math_input_scale = utils::scalar_to<MathT>(input_scale);
34- const auto scalar_func =
35- at::native::get_scalar_elu_elementwise_func<CTYPE, MathT>(
36- math_alpha, math_scale, math_input_scale);
37- const auto vec_func = at::native::get_vectorized_elu_elementwise_func<CTYPE>(
38- math_alpha, math_scale, math_input_scale);
3935
40- ::executorch::extension::parallel_for (
41- 0 ,
42- out.numel(),
43- ::executorch::extension::internal::GRAIN_SIZE,
44- [&](const auto begin, const auto end) {
45- using Vec = at::vec::Vectorized<CTYPE>;
46- const auto vectorized_begin =
47- begin + (Vec::size () - begin % Vec::size ()) % Vec::size ();
48- const auto vectorized_end = end - (end % Vec::size ());
49- // Scalar prologue.
50- for (const auto idx : c10::irange (begin, vectorized_begin)) {
51- out_data[idx] = scalar_func (in_data[idx]);
52- }
36+ using Vec = at::vec::Vectorized<CTYPE>;
37+ at::vec::map (
38+ [math_alpha, math_scale, math_input_scale](Vec x) {
39+ auto scaled_input = x * Vec (static_cast <CTYPE>(math_input_scale));
40+ auto zero = Vec (static_cast <CTYPE>(0 ));
41+ auto one = Vec (static_cast <CTYPE>(1 ));
42+ auto alpha_vec = Vec (static_cast <CTYPE>(math_alpha));
43+ auto scale_vec = Vec (static_cast <CTYPE>(math_scale));
5344
54- // Main vectorized loop.
55- for (auto idx = vectorized_begin; idx < vectorized_end;
56- idx += Vec::size ()) {
57- auto result_vec = vec_func (Vec::loadu (&in_data[idx]));
58- result_vec.store (&out_data[idx]);
59- }
60-
61- // Scalar epilogue.
62- for (const auto idx : c10::irange (vectorized_end, end)) {
63- out_data[idx] = scalar_func (in_data[idx]);
64- }
65- });
45+ auto pos_mask = scaled_input > zero;
46+ auto neg_result = alpha_vec * ((scaled_input.exp ()) - one);
47+ auto result = Vec::blendv (neg_result, scaled_input, pos_mask);
48+ return result * scale_vec;
49+ },
50+ out_data,
51+ in_data,
52+ out.numel ());
6653}
6754} // namespace
6855
0 commit comments