Skip to content

Commit 1a271b0

Browse files
committed
Add: AVX-512 variants
1 parent c32e3fa commit 1a271b0

File tree

2 files changed

+107
-3
lines changed

2 files changed

+107
-3
lines changed

reduce_bench.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ int main(int argc, char **argv) {
9898
->UseRealTime();
9999
bm::RegisterBenchmark("openmp<f32>", &make<openmp_t>)->MinTime(10)->UseRealTime();
100100

101-
// x86
101+
// x86 AVX2
102102
#if defined(__AVX2__)
103103
bm::RegisterBenchmark("avx2<f32>", &make<avx2_f32_t>)->MinTime(10)->UseRealTime();
104104
bm::RegisterBenchmark("avx2<f32kahan>", &make<avx2_f32kahan_t>)->MinTime(10)->UseRealTime();
@@ -107,6 +107,17 @@ int main(int argc, char **argv) {
107107
bm::RegisterBenchmark("avx2<f64>@threads", &make<threads_gt<avx2_f64_t>>)->MinTime(10)->UseRealTime();
108108
bm::RegisterBenchmark("sse<f32aligned>@threads", &make<threads_gt<sse_f32aligned_t>>)->MinTime(10)->UseRealTime();
109109
#endif
110+
// x86 AVX-512
111+
#if defined(__AVX512F__)
    // Each kernel is registered twice: single-threaded, and wrapped into `threads_gt`
    // under the same name with an "@threads" suffix.
    bm::RegisterBenchmark("avx512<f32streamed>", &make<avx512_f32streamed_t>)->MinTime(10)->UseRealTime();
    bm::RegisterBenchmark("avx512<f32streamed>@threads", &make<threads_gt<avx512_f32streamed_t>>)
        ->MinTime(10)
        ->UseRealTime();
    // The "unrolled" entries must instantiate `avx512_f32unrolled_t`; registering the streamed
    // kernel under this name would silently report the wrong implementation's timings.
    bm::RegisterBenchmark("avx512<f32unrolled>", &make<avx512_f32unrolled_t>)->MinTime(10)->UseRealTime();
    bm::RegisterBenchmark("avx512<f32unrolled>@threads", &make<threads_gt<avx512_f32unrolled_t>>)
        ->MinTime(10)
        ->UseRealTime();
#endif
110121

111122
// CUDA
112123
#if defined(__CUDACC__)

reduce_cpu.hpp

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#include <omp.h> // `#pragma omp`
66
#include <thread> // `std::thread`
77

8-
#if defined(__AVX2__)
9-
#include <immintrin.h> // AVX2 intrinsics
8+
#if defined(__AVX2__) || defined(__AVX512F__)
9+
#include <immintrin.h> // x86 intrinsics
1010
#endif
1111

1212
namespace ashvardanian::reduce {
@@ -222,6 +222,99 @@ struct avx2_f32aligned_t {
222222

223223
#endif
224224

225+
#if defined(__AVX512F__)

/// Computes the sum of a sequence of float values using SIMD @b AVX-512 intrinsics,
/// using streaming (non-temporal) loads and bidirectional accumulation into 2 separate ZMM registers.
///
/// The range is `[begin_, end_)`. `begin_` must be 64-byte aligned, as required by the aligned
/// streaming load. The length may be arbitrary: only whole 16-float registers are summed with SIMD,
/// the remaining `< 16` trailing floats are added with a scalar loop.
struct avx512_f32streamed_t {
    float const *const begin_ = nullptr;
    float const *const end_ = nullptr;

    /// @return The sum of all floats in `[begin_, end_)`; `0` for an empty range.
    float operator()() const noexcept {
        // A 512-bit ZMM register holds 16 single-precision floats, so every load consumes 16 elements.
        constexpr int lanes = 16;

        // Restrict the vectorized part to a whole number of registers, so the reverse cursor
        // stays 64-byte aligned whenever `begin_` is.
        auto const vectorized = (end_ - begin_) / lanes * lanes;
        auto it_begin = begin_;
        auto it_end = begin_ + vectorized;

        // The intrinsic takes a non-const pointer in some compiler headers, hence the `const_cast`.
        auto const stream_load = [](float const *ptr) noexcept {
            return _mm512_castsi512_ps(_mm512_stream_load_si512(const_cast<float *>(ptr)));
        };

        __m512 acc1 = _mm512_setzero_ps(); // Accumulator for forward direction
        __m512 acc2 = _mm512_setzero_ps(); // Accumulator for reverse direction

        // Process one register (16 floats) from each end per iteration
        for (; it_end - it_begin >= 2 * lanes; it_begin += lanes, it_end -= lanes) {
            acc1 = _mm512_add_ps(acc1, stream_load(it_begin));
            acc2 = _mm512_add_ps(acc2, stream_load(it_end - lanes));
        }
        // At most one whole register can be left in the middle
        if (it_end - it_begin >= lanes) {
            acc1 = _mm512_add_ps(acc1, stream_load(it_begin));
            it_begin += lanes;
        }

        // Combine the accumulators
        float sum = _mm512_reduce_add_ps(_mm512_add_ps(acc1, acc2));

        // Scalar tail: the elements past the last whole register
        for (auto it = begin_ + vectorized; it != end_; ++it)
            sum += *it;
        return sum;
    }
};

/// Computes the sum of a sequence of float values using SIMD @b AVX-512 intrinsics,
/// using caching loads and bidirectional traversal with 16 independent ZMM accumulators.
///
/// The range is `[begin_, end_)`. `begin_` must be 64-byte aligned, as required by the aligned
/// loads. The length may be arbitrary: only whole 16-float registers are summed with SIMD,
/// the remaining `< 16` trailing floats are added with a scalar loop.
struct avx512_f32unrolled_t {
    float const *const begin_ = nullptr;
    float const *const end_ = nullptr;

    /// @return The sum of all floats in `[begin_, end_)`; `0` for an empty range.
    float operator()() const noexcept {
        // AVX-512 exposes 32 ZMM registers, each holding 16 floats.
        // We want to keep half of them free for loading buffers, and the rest are used for accumulation:
        // 8 accumulators in the forward direction and 8 in the reverse direction.
        constexpr int lanes = 16;             // floats per ZMM register
        constexpr int unroll = 8;             // accumulators per direction
        constexpr int chunk = lanes * unroll; // floats consumed per direction per iteration

        // Restrict the vectorized part to a whole number of registers, so the reverse cursor
        // stays 64-byte aligned whenever `begin_` is.
        auto const vectorized = (end_ - begin_) / lanes * lanes;
        auto it_begin = begin_;
        auto it_end = begin_ + vectorized;

        __m512 fwd0 = _mm512_setzero_ps(), rev0 = _mm512_setzero_ps();
        __m512 fwd1 = _mm512_setzero_ps(), rev1 = _mm512_setzero_ps();
        __m512 fwd2 = _mm512_setzero_ps(), rev2 = _mm512_setzero_ps();
        __m512 fwd3 = _mm512_setzero_ps(), rev3 = _mm512_setzero_ps();
        __m512 fwd4 = _mm512_setzero_ps(), rev4 = _mm512_setzero_ps();
        __m512 fwd5 = _mm512_setzero_ps(), rev5 = _mm512_setzero_ps();
        __m512 fwd6 = _mm512_setzero_ps(), rev6 = _mm512_setzero_ps();
        __m512 fwd7 = _mm512_setzero_ps(), rev7 = _mm512_setzero_ps();

        // Process in chunks of 16 floats x 8 ZMM registers = 128 floats in each direction
        for (; it_end - it_begin >= 2 * chunk; it_begin += chunk, it_end -= chunk) {
            fwd0 = _mm512_add_ps(fwd0, _mm512_load_ps(it_begin + lanes * 0));
            fwd1 = _mm512_add_ps(fwd1, _mm512_load_ps(it_begin + lanes * 1));
            fwd2 = _mm512_add_ps(fwd2, _mm512_load_ps(it_begin + lanes * 2));
            fwd3 = _mm512_add_ps(fwd3, _mm512_load_ps(it_begin + lanes * 3));
            fwd4 = _mm512_add_ps(fwd4, _mm512_load_ps(it_begin + lanes * 4));
            fwd5 = _mm512_add_ps(fwd5, _mm512_load_ps(it_begin + lanes * 5));
            fwd6 = _mm512_add_ps(fwd6, _mm512_load_ps(it_begin + lanes * 6));
            fwd7 = _mm512_add_ps(fwd7, _mm512_load_ps(it_begin + lanes * 7));
            rev0 = _mm512_add_ps(rev0, _mm512_load_ps(it_end - lanes * (1 + 0)));
            rev1 = _mm512_add_ps(rev1, _mm512_load_ps(it_end - lanes * (1 + 1)));
            rev2 = _mm512_add_ps(rev2, _mm512_load_ps(it_end - lanes * (1 + 2)));
            rev3 = _mm512_add_ps(rev3, _mm512_load_ps(it_end - lanes * (1 + 3)));
            rev4 = _mm512_add_ps(rev4, _mm512_load_ps(it_end - lanes * (1 + 4)));
            rev5 = _mm512_add_ps(rev5, _mm512_load_ps(it_end - lanes * (1 + 5)));
            rev6 = _mm512_add_ps(rev6, _mm512_load_ps(it_end - lanes * (1 + 6)));
            rev7 = _mm512_add_ps(rev7, _mm512_load_ps(it_end - lanes * (1 + 7)));
        }
        // Fewer than two full chunks remain in the middle: consume them one register at a time
        for (; it_end - it_begin >= lanes; it_begin += lanes)
            fwd0 = _mm512_add_ps(fwd0, _mm512_load_ps(it_begin));

        // Combine the accumulators as a balanced tree; every accumulator participates exactly once
        __m512 const fwd = _mm512_add_ps(_mm512_add_ps(_mm512_add_ps(fwd0, fwd1), _mm512_add_ps(fwd2, fwd3)),
                                         _mm512_add_ps(_mm512_add_ps(fwd4, fwd5), _mm512_add_ps(fwd6, fwd7)));
        __m512 const rev = _mm512_add_ps(_mm512_add_ps(_mm512_add_ps(rev0, rev1), _mm512_add_ps(rev2, rev3)),
                                         _mm512_add_ps(_mm512_add_ps(rev4, rev5), _mm512_add_ps(rev6, rev7)));
        float sum = _mm512_reduce_add_ps(_mm512_add_ps(fwd, rev));

        // Scalar tail: the elements past the last whole register
        for (auto it = begin_ + vectorized; it != end_; ++it)
            sum += *it;
        return sum;
    }
};

#endif
317+
225318
#pragma region Multi Core
226319

227320
/// Computes the sum of a sequence of float values using @b OpenMP on-CPU multi-core reductions acceleration.

0 commit comments

Comments
 (0)