|
16 | 16 | #include <immintrin.h> // x86 intrinsics |
17 | 17 | #endif |
18 | 18 |
|
| 19 | +#if defined(__ARM_NEON) |
| 20 | +#include <arm_neon.h> // ARM NEON intrinsics |
| 21 | +#endif |
| 22 | + |
| 23 | +#if defined(__ARM_FEATURE_SVE) |
| 24 | +#include <arm_sve.h> // ARM SVE intrinsics |
| 25 | +#endif |
| 26 | + |
19 | 27 | namespace ashvardanian::reduce { |
20 | 28 |
|
21 | 29 | /** |
@@ -454,7 +462,74 @@ class avx512_f32interleaving_t { |
454 | 462 |
|
455 | 463 | #endif // defined(__AVX512F__) |
456 | 464 |
|
457 | | -#pragma region Multicore |
| 465 | +#pragma endregion x86 |
| 466 | +#pragma region ARM |
| 467 | +#if defined(__ARM_NEON) |
| 468 | + |
| 469 | +/** |
| 470 | + * @brief Computes the sum of a sequence of float values using SIMD @b NEON intrinsics, |
| 471 | + * processing 128 bits (4 floats) per vector. |
| 472 | + */ |
| 473 | +class neon_f32_t { |
| 474 | + float const *const begin_ = nullptr; |
| 475 | + float const *const end_ = nullptr; |
| 476 | + |
| 477 | + public: |
| 478 | + neon_f32_t() = default; |
| 479 | + neon_f32_t(float const *b, float const *e) noexcept : begin_(b), end_(e) {} |
| 480 | + |
| 481 | + float operator()() const noexcept { |
| 482 | + auto const count_neon = (end_ - begin_) / 4; |
| 483 | + auto const last_neon_ptr = begin_ + count_neon * 4; |
| 484 | + auto it = begin_; |
| 485 | + |
| 486 | + float32x4_t running_sums = vdupq_n_f32(0.f); |
| 487 | + for (; it != last_neon_ptr; it += 4) running_sums = vaddq_f32(running_sums, vld1q_f32(it)); |
| 488 | + |
| 489 | + float running_sum = vaddvq_f32(running_sums); |
| 490 | + for (; it != end_; ++it) running_sum += *it; |
| 491 | + return running_sum; |
| 492 | + } |
| 493 | +}; |
| 494 | + |
| 495 | +#endif // defined(__ARM_NEON) |
| 496 | + |
| 497 | +#if defined(__ARM_FEATURE_SVE) |
| 498 | + |
| 499 | +/** |
| 500 | + * @brief Computes the sum of a sequence of float values using SIMD @b SVE intrinsics, |
| 501 | + * processing multiple entries per cycle. |
| 502 | + */ |
| 503 | +class sve_f32_t { |
| 504 | + float const *const begin_ = nullptr; |
| 505 | + float const *const end_ = nullptr; |
| 506 | + |
| 507 | + public: |
| 508 | + sve_f32_t() = default; |
| 509 | + sve_f32_t(float const *b, float const *e) noexcept : begin_(b), end_(e) {} |
| 510 | + |
| 511 | + float operator()() const noexcept { |
| 512 | + auto const sve_register_width = svcntw(); |
| 513 | + auto const input_size = static_cast<std::size_t>(end_ - begin_); |
| 514 | + |
| 515 | + svfloat32_t running_sums = svdup_f32(0.f); |
| 516 | + for (std::size_t start_offset = 0; start_offset < input_size; start_offset += sve_register_width) { |
| 517 | + svbool_t progress_vec = svwhilelt_b32(start_offset, input_size); |
| 518 | + running_sums = svadd_f32_m(progress_vec, running_sums, svld1(progress_vec, begin_ + start_offset)); |
| 519 | + } |
| 520 | + |
| 521 | + // No need to handle the tail separately |
| 522 | + float const running_sum = svaddv(svptrue_b32(), running_sums); |
| 523 | + return running_sum; |
| 524 | + } |
| 525 | +}; |
| 526 | + |
| 527 | +#endif // defined(__ARM_FEATURE_SVE) |
| 528 | + |
| 529 | +#pragma endregion ARM |
| 530 | +#pragma endregion Handwritten SIMD Kernels |
| 531 | + |
| 532 | +#pragma region - Multicore |
458 | 533 |
|
459 | 534 | /** |
460 | 535 | * @brief Computes the sum of a sequence of float values using @b OpenMP on-CPU |
|
0 commit comments