@@ -118,43 +118,63 @@ broadcast_like(const V& input, const R& other) {
118118
119119namespace detail {
120120
121- template <size_t N>
122- struct apply_recur_impl ;
123-
124121template <typename F, size_t N, typename Output, typename ... Args>
125122struct apply_impl {
126- KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
127- apply_recur_impl<N>::call (fun, result, inputs...);
123+ KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
124+ #pragma unroll
125+ for (size_t i = 0 ; i < N; i++) {
126+ output[i] = fun (args[i]...);
127+ }
128128 }
129129};
130130
131- template <size_t N>
132- struct apply_recur_impl {
133- static constexpr size_t K = round_up_to_power_of_two(N) / 2 ;
131+ template <typename F, size_t N, typename Output, typename ... Args>
132+ struct apply_fastmath_impl : apply_impl<F, N, Output, Args...> {};
133+
134+ template <typename F, size_t N, typename Output, typename ... Args>
135+ struct map_impl {
136+ static constexpr size_t packet_size = preferred_vector_size<Output>::value;
137+
138+ KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
139+ if constexpr (N / packet_size > 0 ) {
140+ #pragma unroll
141+ for (size_t i = 0 ; i < N - N % packet_size; i += packet_size) {
142+ apply_impl<F, packet_size, Output, Args...>::call (fun, output + i, (args + i)...);
143+ }
144+ }
134145
135- template <typename F, typename Output, typename ... Args>
136- KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
137- apply_impl<F, K, Output, Args...>::call (fun, result, inputs...);
138- apply_impl<F, N - K, Output, Args...>::call (fun, result + K, (inputs + K)...);
146+ if constexpr (N % packet_size > 0 ) {
147+ #pragma unroll
148+ for (size_t i = N - N % packet_size; i < N; i++) {
149+ apply_impl<F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
150+ }
151+ }
139152 }
140153};
141154
142- template <>
143- struct apply_recur_impl <0 > {
144- template <typename F, typename Output, typename ... Args>
145- KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {}
146- };
155+ template <typename F, size_t N, typename Output, typename ... Args>
156+ struct fast_map_impl {
157+ static constexpr size_t packet_size = preferred_vector_size<Output>::value;
158+
159+ KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
160+ if constexpr (N / packet_size > 0 ) {
161+ #pragma unroll
162+ for (size_t i = 0 ; i < N - N % packet_size; i += packet_size) {
163+ apply_fastmath_impl<F, packet_size, Output, Args...>::call (
164+ fun,
165+ output + i,
166+ (args + i)...);
167+ }
168+ }
147169
148- template <>
149- struct apply_recur_impl <1 > {
150- template <typename F, typename Output, typename ... Args>
151- KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
152- result[0 ] = fun (inputs[0 ]...);
170+ if constexpr (N % packet_size > 0 ) {
171+ #pragma unroll
172+ for (size_t i = N - N % packet_size; i < N; i++) {
173+ apply_fastmath_impl<F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
174+ }
175+ }
153176 }
154177};
155-
156- template <typename F, size_t N, typename Output, typename ... Args>
157- struct apply_fastmath_impl : apply_impl<F, N, Output, Args...> {};
158178} // namespace detail
159179
160180template <typename F, typename ... Args>
@@ -180,12 +200,12 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
180200 // Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
181201#if KERNEL_FLOAT_FAST_MATH
182202 using apply_impl =
183- detail::apply_fastmath_impl <F, extent_size<E>, Output, vector_value_type<Args>...>;
203+ detail::fast_math_impl <F, extent_size<E>, Output, vector_value_type<Args>...>;
184204#else
185- using apply_impl = detail::apply_impl <F, extent_size<E>, Output, vector_value_type<Args>...>;
205+ using map_impl = detail::map_impl <F, extent_size<E>, Output, vector_value_type<Args>...>;
186206#endif
187207
188- apply_impl ::call (
208+ map_impl ::call (
189209 fun,
190210 result.data (),
191211 (detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
@@ -205,7 +225,7 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> fast_map(F fun, const Args&... args) {
205225 using E = broadcast_vector_extent_type<Args...>;
206226 vector_storage<Output, extent_size<E>> result;
207227
208- detail::apply_fastmath_impl <F, extent_size<E>, Output, vector_value_type<Args>...>::call (
228+ detail::fast_map_impl <F, extent_size<E>, Output, vector_value_type<Args>...>::call (
209229 fun,
210230 result.data (),
211231 (detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
0 commit comments