|
16 | 16 | #include "hdr/stdint_proxy.h" |
17 | 17 | #include "src/__support/CPP/algorithm.h" |
18 | 18 | #include "src/__support/CPP/limits.h" |
| 19 | +#include "src/__support/CPP/tuple.h" |
19 | 20 | #include "src/__support/CPP/type_traits.h" |
| 21 | +#include "src/__support/CPP/utility/integer_sequence.h" |
20 | 22 | #include "src/__support/macros/attributes.h" |
21 | 23 | #include "src/__support/macros/config.h" |
22 | 24 |
|
@@ -51,6 +53,7 @@ template <typename T> LIBC_INLINE constexpr size_t native_vector_size = 1; |
51 | 53 | template <typename T> LIBC_INLINE constexpr T poison() { |
52 | 54 | return __builtin_nondeterministic_value(T()); |
53 | 55 | } |
| 56 | + |
54 | 57 | } // namespace internal |
55 | 58 |
|
56 | 59 | // Type aliases. |
@@ -273,6 +276,77 @@ LIBC_INLINE constexpr static simd<T, N> select(simd<bool, N> m, simd<T, N> x, |
273 | 276 | return m ? x : y; |
274 | 277 | } |
275 | 278 |
|
| 279 | +namespace internal { |
| 280 | +template <typename T, size_t N, size_t O, size_t... I> |
| 281 | +LIBC_INLINE constexpr static cpp::simd<T, sizeof...(I)> |
| 282 | +extend(cpp::simd<T, N> x, cpp::index_sequence<I...>) { |
| 283 | + return __builtin_shufflevector(x, x, (I < O ? static_cast<int>(I) : -1)...); |
| 284 | +} |
| 285 | +template <typename T, size_t N, size_t M, size_t O> |
| 286 | +LIBC_INLINE constexpr static auto extend(cpp::simd<T, N> x) { |
| 287 | + if constexpr (N == M) |
| 288 | + return x; |
| 289 | + else if constexpr (M <= 2 * N) |
| 290 | + return extend<T, N, M>(x, cpp::make_index_sequence<M>{}); |
| 291 | + else |
| 292 | + return extend<T, 2 * N, M, O>( |
| 293 | + extend<T, N, 2 * N>(x, cpp::make_index_sequence<2 * N>{})); |
| 294 | +} |
| 295 | +template <typename T, size_t N, size_t M, size_t... I> |
| 296 | +LIBC_INLINE constexpr static cpp::simd<T, N + M> |
| 297 | +concat(cpp::simd<T, N> x, cpp::simd<T, M> y, cpp::index_sequence<I...>) { |
| 298 | + constexpr size_t L = (N > M ? N : M); |
| 299 | + |
| 300 | + auto x_ext = extend<T, N, L, N>(x); |
| 301 | + auto y_ext = extend<T, M, L, M>(y); |
| 302 | + |
| 303 | + auto remap = [](size_t idx) -> int { |
| 304 | + if (idx < N) |
| 305 | + return static_cast<int>(idx); |
| 306 | + if (idx < N + M) |
| 307 | + return static_cast<int>((idx - N) + L); |
| 308 | + return -1; |
| 309 | + }; |
| 310 | + |
| 311 | + return __builtin_shufflevector(x_ext, y_ext, remap(I)...); |
| 312 | +} |
| 313 | + |
| 314 | +template <typename T, size_t N, size_t Count, size_t Offset, size_t... I> |
| 315 | +LIBC_INLINE constexpr static cpp::simd<T, Count> |
| 316 | +slice(cpp::simd<T, N> x, cpp::index_sequence<I...>) { |
| 317 | + return __builtin_shufflevector(x, x, (Offset + I)...); |
| 318 | +} |
| 319 | +template <typename T, size_t N, size_t Offset, size_t Head, size_t... Tail> |
| 320 | +LIBC_INLINE constexpr static auto split(cpp::simd<T, N> x) { |
| 321 | + auto first = cpp::make_tuple( |
| 322 | + slice<T, N, Head, Offset>(x, cpp::make_index_sequence<Head>{})); |
| 323 | + if constexpr (sizeof...(Tail) > 0) |
| 324 | + return cpp::tuple_cat(first, split<T, N, Offset + Head, Tail...>(x)); |
| 325 | + else |
| 326 | + return first; |
| 327 | +} |
| 328 | + |
| 329 | +} // namespace internal |
| 330 | + |
| 331 | +// Shuffling helpers. |
| 332 | +template <typename T, size_t N, size_t M> |
| 333 | +LIBC_INLINE constexpr static auto concat(cpp::simd<T, N> x, cpp::simd<T, M> y) { |
| 334 | + return internal::concat(x, y, make_index_sequence<N + M>{}); |
| 335 | +} |
| 336 | +template <typename T, size_t N, size_t M, typename... Rest> |
| 337 | +LIBC_INLINE constexpr static auto concat(cpp::simd<T, N> x, cpp::simd<T, M> y, |
| 338 | + Rest... rest) { |
| 339 | + auto xy = concat(x, y); |
| 340 | + if constexpr (sizeof...(Rest)) |
| 341 | + return concat(xy, rest...); |
| 342 | + else |
| 343 | + return xy; |
| 344 | +} |
| 345 | +template <size_t... Sizes, typename T, size_t N> auto split(cpp::simd<T, N> x) { |
| 346 | + static_assert((... + Sizes) == N, "split sizes must sum to vector size"); |
| 347 | + return internal::split<T, N, 0, Sizes...>(x); |
| 348 | +} |
| 349 | + |
276 | 350 | // TODO: where expressions, scalar overloads, ABI types. |
277 | 351 |
|
278 | 352 | } // namespace cpp |
|
0 commit comments