|
5 | 5 | #include <omp.h> // `#pragma omp` |
6 | 6 | #include <thread> // `std::thread` |
7 | 7 |
|
8 | | -#if defined(__AVX2__) |
9 | | -#include <immintrin.h> // AVX2 intrinsics |
| 8 | +#if defined(__AVX2__) || defined(__AVX512F__) |
| 9 | +#include <immintrin.h> // x86 intrinsics |
10 | 10 | #endif |
11 | 11 |
|
12 | 12 | namespace ashvardanian::reduce { |
@@ -222,6 +222,99 @@ struct avx2_f32aligned_t { |
222 | 222 |
|
223 | 223 | #endif |
224 | 224 |
|
| 225 | +#if defined(__AVX512F__) |
| 226 | + |
| 227 | +/// Computes the sum of a sequence of float values using SIMD @b AVX-512 intrinsics, |
| 228 | +/// using streaming loads and bidirectional accumulation into 2 separate ZMM registers. |
| 229 | +struct avx512_f32streamed_t { |
| 230 | + float const *const begin_ = nullptr; |
| 231 | + float const *const end_ = nullptr; |
| 232 | + |
| 233 | + float operator()() const noexcept { |
| 234 | + auto it_begin = begin_; |
| 235 | + auto it_end = end_; |
| 236 | + |
| 237 | + __m512 acc1 = _mm512_set1_ps(0.0f); // Accumulator for forward direction |
| 238 | + __m512 acc2 = _mm512_set1_ps(0.0f); // Accumulator for reverse direction |
| 239 | + |
| 240 | + // Process in chunks of 32 floats in each direction |
| 241 | + for (; it_end - it_begin >= 64; it_begin += 32, it_end -= 32) { |
| 242 | + acc1 = _mm512_add_ps(acc1, _mm512_castsi512_ps(_mm512_stream_load_si512((void *)(it_begin)))); |
| 243 | + acc2 = _mm512_add_ps(acc2, _mm512_castsi512_ps(_mm512_stream_load_si512((void *)(it_end - 32)))); |
| 244 | + } |
| 245 | + if (it_end - it_begin >= 32) { |
| 246 | + acc1 = _mm512_add_ps(acc1, _mm512_castsi512_ps(_mm512_stream_load_si512((void *)(it_begin)))); |
| 247 | + it_begin += 32; |
| 248 | + } |
| 249 | + |
| 250 | + // Combine the accumulators |
| 251 | + __m512 acc = _mm512_add_ps(acc1, acc2); |
| 252 | + float sum = _mm512_reduce_add_ps(acc); |
| 253 | + while (it_begin < it_end) |
| 254 | + sum += *it_begin++; |
| 255 | + return sum; |
| 256 | + } |
| 257 | +}; |
| 258 | + |
| 259 | +/// Computes the sum of a sequence of float values using SIMD @b AVX-512 intrinsics, |
| 260 | +/// using caching loads and bidirectional traversal using all the available ZMM registers. |
| 261 | +struct avx512_f32unrolled_t { |
| 262 | + float const *const begin_ = nullptr; |
| 263 | + float const *const end_ = nullptr; |
| 264 | + |
| 265 | + float operator()() const noexcept { |
| 266 | + auto it_begin = begin_; |
| 267 | + auto it_end = end_; |
| 268 | + |
| 269 | + // We have a grand-total of 32 floats in a ZMM register. |
| 270 | + // We want to keep half of them free for loading buffers, and the rest can be used for accumulation: |
| 271 | + // 8 in the forward direction, 8 in the reverse direction, and 16 for the accumulator. |
| 272 | + __m512 fwd0 = _mm512_set1_ps(0.0f), rev0 = _mm512_set1_ps(0.0f); |
| 273 | + __m512 fwd1 = _mm512_set1_ps(0.0f), rev1 = _mm512_set1_ps(0.0f); |
| 274 | + __m512 fwd2 = _mm512_set1_ps(0.0f), rev2 = _mm512_set1_ps(0.0f); |
| 275 | + __m512 fwd3 = _mm512_set1_ps(0.0f), rev3 = _mm512_set1_ps(0.0f); |
| 276 | + __m512 fwd4 = _mm512_set1_ps(0.0f), rev4 = _mm512_set1_ps(0.0f); |
| 277 | + __m512 fwd5 = _mm512_set1_ps(0.0f), rev5 = _mm512_set1_ps(0.0f); |
| 278 | + __m512 fwd6 = _mm512_set1_ps(0.0f), rev6 = _mm512_set1_ps(0.0f); |
| 279 | + __m512 fwd7 = _mm512_set1_ps(0.0f), rev7 = _mm512_set1_ps(0.0f); |
| 280 | + |
| 281 | + // Process in chunks of 32 floats x 8 ZMM registers = 256 floats in each direction |
| 282 | + for (; it_end - it_begin >= 512; it_begin += 256, it_end -= 256) { |
| 283 | + fwd0 = _mm512_add_ps(fwd0, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 0)))); |
| 284 | + fwd1 = _mm512_add_ps(fwd1, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 1)))); |
| 285 | + fwd2 = _mm512_add_ps(fwd2, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 2)))); |
| 286 | + fwd3 = _mm512_add_ps(fwd3, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 3)))); |
| 287 | + fwd4 = _mm512_add_ps(fwd4, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 4)))); |
| 288 | + fwd5 = _mm512_add_ps(fwd5, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 5)))); |
| 289 | + fwd6 = _mm512_add_ps(fwd6, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 6)))); |
| 290 | + fwd7 = _mm512_add_ps(fwd7, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_begin + 32 * 7)))); |
| 291 | + rev0 = _mm512_add_ps(rev0, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 0))))); |
| 292 | + rev1 = _mm512_add_ps(rev1, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 1))))); |
| 293 | + rev2 = _mm512_add_ps(rev2, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 2))))); |
| 294 | + rev3 = _mm512_add_ps(rev3, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 3))))); |
| 295 | + rev4 = _mm512_add_ps(rev4, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 4))))); |
| 296 | + rev5 = _mm512_add_ps(rev5, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 5))))); |
| 297 | + rev6 = _mm512_add_ps(rev6, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 6))))); |
| 298 | + rev7 = _mm512_add_ps(rev7, _mm512_castsi512_ps(_mm512_load_si512((void *)(it_end - 32 * (1 + 7))))); |
| 299 | + } |
| 300 | + for (; it_end - it_begin >= 32; it_begin += 32) |
| 301 | + fwd1 = _mm512_add_ps(fwd1, _mm512_castsi512_ps(_mm512_stream_load_si512((void *)(it_begin)))); |
| 302 | + |
| 303 | + // Combine the accumulators |
| 304 | + __m512 fwd = _mm512_add_ps(_mm512_add_ps(_mm512_add_ps(fwd0, fwd1), _mm512_add_ps(fwd2, fwd3)), |
| 305 | + _mm512_add_ps(_mm512_add_ps(fwd4, fwd5), _mm512_add_ps(fwd5, fwd7))); |
| 306 | + __m512 rev = _mm512_add_ps(_mm512_add_ps(_mm512_add_ps(rev0, rev1), _mm512_add_ps(rev2, rev3)), |
| 307 | + _mm512_add_ps(_mm512_add_ps(rev4, rev5), _mm512_add_ps(rev5, rev7))); |
| 308 | + __m512 acc = _mm512_add_ps(fwd, rev); |
| 309 | + float sum = _mm512_reduce_add_ps(acc); |
| 310 | + while (it_begin < it_end) |
| 311 | + sum += *it_begin++; |
| 312 | + return sum; |
| 313 | + } |
| 314 | +}; |
| 315 | + |
| 316 | +#endif |
| 317 | + |
225 | 318 | #pragma region Multi Core |
226 | 319 |
|
227 | 320 | /// Computes the sum of a sequence of float values using @b OpenMP on-CPU multi-core reductions acceleration. |
|
0 commit comments