Skip to content

Commit 4b68f09

Browse files
committed
Add range(fun) utility function
1 parent 755a00e commit 4b68f09

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

include/kernel_float/iterate.h

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,35 @@ void for_each(V&& input, F fun) {
3030
namespace detail {
3131
template<typename T, size_t N>
3232
struct range_impl {
33-
KERNEL_FLOAT_INLINE
34-
static vector_storage<T, N> call() {
33+
template<typename F>
34+
KERNEL_FLOAT_INLINE static vector_storage<T, N> call(F fun) {
3535
vector_storage<T, N> result;
3636

3737
#pragma unroll
3838
for (size_t i = 0; i < N; i++) {
39-
result.data()[i] = T(i);
39+
result.data()[i] = fun(i);
4040
}
4141

4242
return result;
4343
}
4444
};
4545
} // namespace detail
4646

47+
/**
48+
* Generate vector consisting of the result `fun(0)...fun(N-1)`
49+
*
50+
* Example
51+
* =======
52+
* ```
53+
* // Returns [0.0f, 2.0f, 4.0f]
54+
* vec<float, 3> vec = range<3>([](auto i){ return float(i * 2.0f); });
55+
* ```
56+
*/
57+
template<size_t N, typename F, typename T = result_t<F, size_t>>
58+
KERNEL_FLOAT_INLINE vector<T, extent<N>> range(F fun) {
59+
return detail::range_impl<T, N>::call(fun);
60+
}
61+
4762
/**
4863
* Generate vector consisting of the numbers `0...N-1` of type `T`
4964
*
@@ -56,7 +71,7 @@ struct range_impl {
5671
*/
5772
template<typename T, size_t N>
5873
KERNEL_FLOAT_INLINE vector<T, extent<N>> range() {
59-
return detail::range_impl<T, N>::call();
74+
return detail::range_impl<T, N>::call(ops::cast<size_t, T>());
6075
}
6176

6277
/**
@@ -71,7 +86,7 @@ KERNEL_FLOAT_INLINE vector<T, extent<N>> range() {
7186
*/
7287
template<typename V>
7388
KERNEL_FLOAT_INLINE into_vector_type<V> range_like(const V& = {}) {
74-
return detail::range_impl<vector_value_type<V>, vector_extent<V>>::call();
89+
return range<vector_value_type<V>, vector_extent<V>>();
7590
}
7691

7792
/**
@@ -96,7 +111,7 @@ KERNEL_FLOAT_INLINE into_vector_type<V> range_like(const V& = {}) {
96111
*/
97112
template<typename T = size_t, typename V>
98113
KERNEL_FLOAT_INLINE vector<T, vector_extent_type<V>> each_index(const V& = {}) {
99-
return detail::range_impl<T, vector_extent<V>>::call();
114+
return range<T, vector_extent<V>>();
100115
}
101116

102117
namespace detail {

0 commit comments

Comments
 (0)