@@ -15,28 +15,25 @@ struct zip_helper {
1515 template <size_t ... Is>
1616 KERNEL_FLOAT_INLINE static Output
1717 call_with_indices (F fun, const Left& left, const Right& right, index_sequence<Is...> = {}) {
18- return Output { fun (left. get (const_index <Is> {} ), right. get (const_index <Is> {} ))...} ;
18+ return vector_traits< Output>:: create ( fun (vector_get <Is>(left ), vector_get <Is>(right ))...) ;
1919 }
2020};
2121
22- template <typename F, typename T, typename L, typename R, size_t N>
23- struct zip_helper <F, vector_compound<T, N>, vector_compound<L, N>, vector_compound<R, N>> {
24- KERNEL_FLOAT_INLINE static vector_compound<T, N>
25- call (F fun, const vector_compound<L, N>& left, const vector_compound<R, N>& right) {
26- static constexpr size_t low_size = vector_compound<T, N>::low_size;
27- static constexpr size_t high_size = vector_compound<T, N>::high_size;
28-
29- return {
30- zip_helper<
31- F,
32- vector_storage<T, low_size>,
33- vector_storage<L, low_size>,
34- vector_storage<R, low_size>>::call (fun, left.low (), right.low ()),
35- zip_helper<
36- F,
37- vector_storage<T, high_size>,
38- vector_storage<L, high_size>,
39- vector_storage<R, high_size>>::call (fun, left.high (), right.high ())};
22+ template <typename F, typename V, size_t N>
23+ struct zip_helper <F, nested_array<V, N>, nested_array<V, N>, nested_array<V, N>> {
24+ KERNEL_FLOAT_INLINE static nested_array<V, N>
25+ call (F fun, const nested_array<V, N>& left, const nested_array<V, N>& right) {
26+ return call (fun, left, right, make_index_sequence<nested_array<V, N>::num_packets> {});
27+ }
28+
29+ private:
30+ template <size_t ... Is>
31+ KERNEL_FLOAT_INLINE static nested_array<V, N> call (
32+ F fun,
33+ const nested_array<V, N>& left,
34+ const nested_array<V, N>& right,
35+ index_sequence<Is...>) {
36+ return {zip_helper<F, V, V, V>::call (fun, left[Is], right[Is])...};
4037 }
4138};
4239}; // namespace detail
@@ -48,7 +45,7 @@ template<typename... Ts>
4845static constexpr size_t common_vector_size = common_size<vector_size<Ts>...>;
4946
5047template <typename F, typename L, typename R>
51- using zip_type = vector_storage <
48+ using zip_type = default_storage_type <
5249 result_t <F, vector_value_type<L>, vector_value_type<R>>,
5350 common_vector_size<L, R>>;
5451
@@ -63,16 +60,19 @@ using zip_type = vector_storage<
6360 * ``zip_common`` for that functionality.
6461 */
6562template <typename F, typename Left, typename Right, typename Output = zip_type<F, Left, Right>>
66- KERNEL_FLOAT_INLINE Output zip (F fun, Left&& left, Right&& right) {
63+ KERNEL_FLOAT_INLINE vector< Output> zip (F fun, Left&& left, Right&& right) {
6764 static constexpr size_t N = vector_size<Output>;
68- return detail::zip_helper<F, Output, into_vector_type<Left>, into_vector_type<Right>>::call (
65+ using LeftInput = default_storage_type<vector_value_type<Left>, N>;
66+ using RightInput = default_storage_type<vector_value_type<Right>, N>;
67+
68+ return detail::zip_helper<F, Output, LeftInput, RightInput>::call (
6969 fun,
70- broadcast<N >(std::forward<Left>(left)),
71- broadcast<N >(std::forward<Right>(right)));
70+ broadcast<LeftInput, Left >(std::forward<Left>(left)),
71+ broadcast<RightInput, Right >(std::forward<Right>(right)));
7272}
7373
7474template <typename F, typename L, typename R>
75- using zip_common_type = vector_storage <
75+ using zip_common_type = default_storage_type <
7676 result_t <F, common_vector_value_type<L, R>, common_vector_value_type<L, R>>,
7777 common_vector_size<L, R>>;
7878
@@ -99,38 +99,50 @@ template<
9999 typename Left,
100100 typename Right,
101101 typename Output = zip_common_type<F, Left, Right>>
102- KERNEL_FLOAT_INLINE Output zip_common (F fun, Left&& left, Right&& right) {
102+ KERNEL_FLOAT_INLINE vector< Output> zip_common (F fun, Left&& left, Right&& right) {
103103 static constexpr size_t N = vector_size<Output>;
104104 using C = common_t <vector_value_type<Left>, vector_value_type<Right>>;
105+ using Input = default_storage_type<C, N>;
105106
106- return detail::zip_helper<F, Output, vector_storage<C, N>, vector_storage<C, N> >::call (
107+ return detail::zip_helper<F, Output, Input, Input >::call (
107108 fun,
108- broadcast<C, N >(std::forward<Left>(left)),
109- broadcast<C, N >(std::forward<Right>(right)));
109+ broadcast<Input, Left >(std::forward<Left>(left)),
110+ broadcast<Input, Right >(std::forward<Right>(right)));
110111}
111112
112- #define KERNEL_FLOAT_DEFINE_BINARY (NAME, EXPR ) \
113- namespace ops { \
114- template <typename T> \
115- struct NAME { \
116- KERNEL_FLOAT_INLINE T operator ()(T left, T right) { \
117- return T (EXPR); \
118- } \
119- }; \
120- } \
121- template <typename L, typename R, typename C = common_vector_value_type<L, R>> \
122- KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME (L&& left, R&& right) { \
123- return zip_common (ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
113+ #define KERNEL_FLOAT_DEFINE_BINARY (NAME, EXPR ) \
114+ namespace ops { \
115+ template <typename T> \
116+ struct NAME { \
117+ KERNEL_FLOAT_INLINE T operator ()(T left, T right) { \
118+ return T (EXPR); \
119+ } \
120+ }; \
121+ } \
122+ template <typename L, typename R, typename C = common_vector_value_type<L, R>> \
123+ KERNEL_FLOAT_INLINE vector< zip_common_type<ops::NAME<C>, L, R>> NAME (L&& left, R&& right) { \
124+ return zip_common (ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
124125 }
125126
126- #define KERNEL_FLOAT_DEFINE_BINARY_OP (NAME, OP ) \
127- KERNEL_FLOAT_DEFINE_BINARY (NAME, left OP right) \
128- template < \
129- typename L, \
130- typename R, \
131- typename C = enabled_t <is_vector<L> || is_vector<R>, common_vector_value_type<L, R>>> \
132- KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> operator OP (L&& left, R&& right) { \
133- return zip_common (ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
127+ #define KERNEL_FLOAT_DEFINE_BINARY_OP (NAME, OP ) \
128+ KERNEL_FLOAT_DEFINE_BINARY (NAME, left OP right) \
129+ template <typename L, typename R, typename C = common_vector_value_type<L, R>> \
130+ KERNEL_FLOAT_INLINE vector<zip_common_type<ops::NAME<C>, L, R>> operator OP ( \
131+ const vector<L>& left, \
132+ const vector<R>& right) { \
133+ return zip_common (ops::NAME<C> {}, left, right); \
134+ } \
135+ template <typename L, typename R, typename C = common_vector_value_type<L, R>> \
136+ KERNEL_FLOAT_INLINE vector<zip_common_type<ops::NAME<C>, L, R>> operator OP ( \
137+ const vector<L>& left, \
138+ const R& right) { \
139+ return zip_common (ops::NAME<C> {}, left, right); \
140+ } \
141+ template <typename L, typename R, typename C = common_vector_value_type<L, R>> \
142+ KERNEL_FLOAT_INLINE vector<zip_common_type<ops::NAME<C>, L, R>> operator OP ( \
143+ const L& left, \
144+ const vector<R>& right) { \
145+ return zip_common (ops::NAME<C> {}, left, right); \
134146 }
135147
136148KERNEL_FLOAT_DEFINE_BINARY_OP (add, +)
@@ -153,7 +165,6 @@ KERNEL_FLOAT_DEFINE_BINARY_OP(bit_xor, ^)
153165// clang-format off
154166template <template <typename T> typename F, typename L, typename R>
155167static constexpr bool vector_assign_allowed =
156- is_vector<L> &&
157168 common_vector_size<L, R> == vector_size<L> &&
158169 is_implicit_convertible<
159170 result_t <
@@ -170,9 +181,9 @@ static constexpr bool vector_assign_allowed =
170181 typename L, \
171182 typename R, \
172183 typename T = enabled_t <vector_assign_allowed<ops::NAME, L, R>, vector_value_type<L>>> \
173- KERNEL_FLOAT_INLINE L & operator OP (L & lhs, R&& rhs) { \
184+ KERNEL_FLOAT_INLINE vector<L> & operator OP (vector<L> & lhs, const R& rhs) { \
174185 using F = ops::NAME<T>; \
175- lhs = zip_common<F, L&, R, into_vector_type<L>> (F {}, lhs, std::forward<R>( rhs)); \
186+ lhs = zip_common<F, const L&, const R&, L> (F {}, lhs. storage (), rhs); \
176187 return lhs; \
177188 }
178189
0 commit comments