1+ #ifndef KERNEL_FLOAT_APPLY_H
2+ #define KERNEL_FLOAT_APPLY_H
3+
4+ #include " base.h"
5+
6+ namespace kernel_float {
7+ namespace detail {
8+
9+ template <typename ... Es>
10+ struct broadcast_extent_helper ;
11+
12+ template <typename E>
13+ struct broadcast_extent_helper <E> {
14+ using type = E;
15+ };
16+
17+ template <size_t N>
18+ struct broadcast_extent_helper <extent<N>, extent<N>> {
19+ using type = extent<N>;
20+ };
21+
22+ template <size_t N>
23+ struct broadcast_extent_helper <extent<1 >, extent<N>> {
24+ using type = extent<N>;
25+ };
26+
27+ template <size_t N>
28+ struct broadcast_extent_helper <extent<N>, extent<1 >> {
29+ using type = extent<N>;
30+ };
31+
32+ template <>
33+ struct broadcast_extent_helper <extent<1 >, extent<1 >> {
34+ using type = extent<1 >;
35+ };
36+
37+ template <typename A, typename B, typename C, typename ... Rest>
38+ struct broadcast_extent_helper <A, B, C, Rest...>:
39+ broadcast_extent_helper<typename broadcast_extent_helper<A, B>::type, C, Rest...> {};
40+
41+ } // namespace detail
42+
43+ template <typename ... Es>
44+ using broadcast_extent = typename detail::broadcast_extent_helper<Es...>::type;
45+
46+ template <typename ... Vs>
47+ using broadcast_vector_extent_type = broadcast_extent<vector_extent_type<Vs>...>;
48+
49+ template <typename From, typename To>
50+ static constexpr bool is_broadcastable = is_same_type<broadcast_extent<From, To>, To>;
51+
52+ template <typename V, typename To>
53+ static constexpr bool is_vector_broadcastable = is_broadcastable<vector_extent_type<V>, To>;
54+
55+ namespace detail {
56+
57+ template <typename T, typename From, typename To>
58+ struct broadcast_impl ;
59+
60+ template <typename T, size_t N>
61+ struct broadcast_impl <T, extent<1 >, extent<N>> {
62+ KERNEL_FLOAT_INLINE static vector_storage<T, N> call (const vector_storage<T, 1 >& input) {
63+ vector_storage<T, N> output;
64+ for (size_t i = 0 ; i < N; i++) {
65+ output.data ()[i] = input.data ()[0 ];
66+ }
67+ return output;
68+ }
69+ };
70+
71+ template <typename T, size_t N>
72+ struct broadcast_impl <T, extent<N>, extent<N>> {
73+ KERNEL_FLOAT_INLINE static vector_storage<T, N> call (vector_storage<T, N> input) {
74+ return input;
75+ }
76+ };
77+
78+ template <typename T>
79+ struct broadcast_impl <T, extent<1 >, extent<1 >> {
80+ KERNEL_FLOAT_INLINE static vector_storage<T, 1 > call (vector_storage<T, 1 > input) {
81+ return input;
82+ }
83+ };
84+
85+ } // namespace detail
86+
87+ /* *
88+ * Takes the given vector `input` and extends its size to a length of `N`. This is only valid if the size of `input`
89+ * is 1 or `N`.
90+ *
91+ * Example
92+ * =======
93+ * ```
94+ * vec<float, 1> a = {1.0f};
95+ * vec<float, 5> x = broadcast<5>(a); // Returns [1.0f, 1.0f, 1.0f, 1.0f, 1.0f]
96+ *
97+ * vec<float, 5> b = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
98+ * vec<float, 5> y = broadcast<5>(b); // Returns [1.0f, 2.0f, 3.0f, 4.0f, 5.0f]
99+ * ```
100+ */
101+ template <size_t N, typename V>
102+ KERNEL_FLOAT_INLINE vector<vector_value_type<V>, extent<N>>
103+ broadcast (const V& input, extent<N> new_size = {}) {
104+ using T = vector_value_type<V>;
105+ return detail::broadcast_impl<T, vector_extent_type<V>, extent<N>>::call (
106+ into_vector_storage (input));
107+ }
108+
109+ /* *
110+ * Takes the given vector `input` and extends its size to the same length as vector `other`. This is only valid if the
111+ * size of `input` is 1 or the same as `other`.
112+ */
113+ template <typename V, typename R>
114+ KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<R>>
115+ broadcast_like (const V& input, const R& other) {
116+ return broadcast (input, vector_extent_type<R> {});
117+ }
118+
119+ namespace detail {
120+
121+ template <size_t N>
122+ struct apply_recur_impl ;
123+
124+ template <typename F, size_t N, typename Output, typename ... Args>
125+ struct apply_impl {
126+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
127+ apply_recur_impl<N>::call (fun, result, inputs...);
128+ }
129+ };
130+
131+ template <size_t N>
132+ struct apply_recur_impl {
133+ static constexpr size_t K = round_up_to_power_of_two(N) / 2 ;
134+
135+ template <typename F, typename Output, typename ... Args>
136+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
137+ apply_impl<F, K, Output, Args...>::call (fun, result, inputs...);
138+ apply_impl<F, N - K, Output, Args...>::call (fun, result + K, (inputs + K)...);
139+ }
140+ };
141+
142+ template <>
143+ struct apply_recur_impl <0 > {
144+ template <typename F, typename Output, typename ... Args>
145+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {}
146+ };
147+
148+ template <>
149+ struct apply_recur_impl <1 > {
150+ template <typename F, typename Output, typename ... Args>
151+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
152+ result[0 ] = fun (inputs[0 ]...);
153+ }
154+ };
155+ } // namespace detail
156+
157+ template <typename F, typename ... Args>
158+ using map_type =
159+ vector<result_t <F, vector_value_type<Args>...>, broadcast_vector_extent_type<Args...>>;
160+
161+ /* *
162+ * Apply the function `F` to each element from the vector `input` and return the results as a new vector.
163+ *
164+ * Examples
165+ * ========
166+ * ```
167+ * vec<float, 4> input = {1.0f, 2.0f, 3.0f, 4.0f};
168+ * vec<float, 4> squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f]
169+ * ```
170+ */
171+ template <typename F, typename ... Args>
172+ KERNEL_FLOAT_INLINE map_type<F, Args...> map (F fun, const Args&... args) {
173+ using Output = result_t <F, vector_value_type<Args>...>;
174+ using E = broadcast_vector_extent_type<Args...>;
175+ vector_storage<Output, E::value> result;
176+
177+ detail::apply_impl<F, E::value, Output, vector_value_type<Args>...>::call (
178+ fun,
179+ result.data (),
180+ (detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
181+ into_vector_storage (args))
182+ .data ())...);
183+
184+ return result;
185+ }
186+
187+ } // namespace kernel_float
188+
189+ #endif // KERNEL_FLOAT_APPLY_H
0 commit comments