
Commit 1b4e500

[0.8] Implement optimized BF16 support for ARM architecture - [MOD-9079] (#641)
* Implement optimized BF16 support for ARM architecture - [MOD-9079] (#623)
* SVE implementation for bf16
* add required build flags and fix implementation
* final fixes and implement benchmarks
* added tests
* implement neon bf16 distance functions
* implement build flow and benchmarks
* added test
* format
* remove redundant check
* typo fix

Co-authored-by: Copilot <[email protected]>

* fixes and cleanup
* fix build
* fix svwhilelt_b16 calls

---------

Co-authored-by: Copilot <[email protected]>

(cherry picked from commit bb41732)

* fix benchmark macros for 0.8

---------

Co-authored-by: GuyAv46 <[email protected]>
Co-authored-by: GuyAv46 <[email protected]>
1 parent 95bef4e commit 1b4e500

15 files changed: +561 additions, -4 deletions

cmake/aarch64InstructionFlags.cmake

Lines changed: 8 additions & 0 deletions
@@ -8,6 +8,8 @@ CHECK_CXX_COMPILER_FLAG("-march=armv7-a+neon" CXX_ARMV7_NEON)
 CHECK_CXX_COMPILER_FLAG("-march=armv8-a" CXX_ARMV8A)
 CHECK_CXX_COMPILER_FLAG("-march=armv8-a+sve" CXX_SVE)
 CHECK_CXX_COMPILER_FLAG("-march=armv9-a+sve2" CXX_SVE2)
+CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+bf16" CXX_NEON_BF16)
+CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+sve+bf16" CXX_SVE_BF16)

 # Only use ARMv9 if both compiler and CPU support it
 if(CXX_SVE2)
@@ -17,6 +19,12 @@ endif()
 if (CXX_ARMV8A OR CXX_ARMV7_NEON)
     add_compile_definitions(OPT_NEON)
 endif()
+if (CXX_NEON_BF16)
+    add_compile_definitions(OPT_NEON_BF16)
+endif()
 if (CXX_SVE)
     add_compile_definitions(OPT_SVE)
 endif()
+if (CXX_SVE_BF16)
+    add_compile_definitions(OPT_SVE_BF16)
+endif()
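
Note: a minimal C++ probe, assuming a little-endian AArch64 toolchain, of what these flag checks are gating. A translation unit like the following only compiles when the bf16 extension is in effect (e.g. under -march=armv8.2-a+bf16), which is exactly what the kernels added later in this commit rely on; it is illustrative, not part of the commit.

// Illustrative probe (not part of the commit): bfloat16_t and vbfdotq_f32 are
// only available when the compiler accepts the +bf16 extension.
#include <arm_neon.h>

float bf16_dot_probe(const bfloat16_t *a, const bfloat16_t *b) {
    float32x4_t acc = vdupq_n_f32(0.0f);
    // Dot-product of 8 bf16 pairs, accumulated into 4 f32 lanes.
    acc = vbfdotq_f32(acc, vld1q_bf16(a), vld1q_bf16(b));
    return vaddvq_f32(acc); // horizontal sum of the lanes
}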

src/VecSim/spaces/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
@@ -85,13 +85,27 @@ if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)|(ARM64)|(armv.*)")
         list(APPEND OPTIMIZATIONS functions/NEON.cpp)
     endif()

+    # NEON bfloat16 support
+    if (CXX_NEON_BF16)
+        message("Building with NEON + BF16")
+        set_source_files_properties(functions/NEON_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+bf16")
+        list(APPEND OPTIMIZATIONS functions/NEON_BF16.cpp)
+    endif()
+
     # SVE support
     if (CXX_SVE)
         message("Building with SVE")
         set_source_files_properties(functions/SVE.cpp PROPERTIES COMPILE_FLAGS "-march=armv8-a+sve")
         list(APPEND OPTIMIZATIONS functions/SVE.cpp)
     endif()

+    # SVE with BF16 support
+    if (CXX_SVE_BF16)
+        message("Building with SVE + BF16")
+        set_source_files_properties(functions/SVE_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+sve+bf16")
+        list(APPEND OPTIMIZATIONS functions/SVE_BF16.cpp)
+    endif()
+
     # SVE2 support
     if (CXX_SVE2)
         message("Building with ARMV9A and SVE2")
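
Only functions/NEON_BF16.cpp and functions/SVE_BF16.cpp are compiled with the extended -march flags, so the BF16 intrinsics stay confined to those translation units while the rest of the library keeps the baseline architecture. As a rough idea of what such a chooser TU does, here is a hypothetical sketch (the real file uses the project's own dispatch macros; the _sketch name is mine): it maps dim % 32 to the matching residual specialization of the BF16_InnerProduct_NEON template added below.

// Hypothetical sketch, not the project's actual chooser: map dim % 32 onto the
// matching `residual` specialization of BF16_InnerProduct_NEON. Assumes the
// kernel template added in this commit is already #included, and that this
// translation unit is built with -march=armv8.2-a+bf16.
#include <array>
#include <cstddef>
#include <utility>

using bf16_dist_func_t = float (*)(const void *, const void *, std::size_t);

template <std::size_t... R>
static constexpr std::array<bf16_dist_func_t, sizeof...(R)>
make_ip_table(std::index_sequence<R...>) {
    return {BF16_InnerProduct_NEON<static_cast<unsigned char>(R)>...};
}

static bf16_dist_func_t Choose_BF16_IP_NEON_BF16_sketch(std::size_t dim) {
    static constexpr auto table = make_ip_table(std::make_index_sequence<32>{});
    return table[dim % 32]; // residual = dim % 32
}
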
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include <arm_neon.h>

inline void InnerProduct_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
    // Load brain-half-precision vectors
    bfloat16x8_t v1 = vld1q_bf16(vec1);
    bfloat16x8_t v2 = vld1q_bf16(vec2);
    vec1 += 8;
    vec2 += 8;
    // Compute multiplications and add to the accumulator
    acc = vbfdotq_f32(acc, v1, v2);
}

template <unsigned char residual> // 0..31
float BF16_InnerProduct_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
    const auto *const v1End = vec1 + dimension;
    float32x4_t acc1 = vdupq_n_f32(0.0f);
    float32x4_t acc2 = vdupq_n_f32(0.0f);
    float32x4_t acc3 = vdupq_n_f32(0.0f);
    float32x4_t acc4 = vdupq_n_f32(0.0f);

    // First, handle the partial chunk residual
    if constexpr (residual % 8) {
        auto constexpr chunk_residual = residual % 8;
        // TODO: special cases for some residuals and benchmark if it's better
        constexpr uint16x8_t mask = {
            0xFFFF,
            (chunk_residual >= 2) ? 0xFFFF : 0,
            (chunk_residual >= 3) ? 0xFFFF : 0,
            (chunk_residual >= 4) ? 0xFFFF : 0,
            (chunk_residual >= 5) ? 0xFFFF : 0,
            (chunk_residual >= 6) ? 0xFFFF : 0,
            (chunk_residual >= 7) ? 0xFFFF : 0,
            0,
        };

        // Load partial vectors
        bfloat16x8_t v1 = vld1q_bf16(vec1);
        bfloat16x8_t v2 = vld1q_bf16(vec2);

        // Apply mask to both vectors
        bfloat16x8_t masked_v1 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask));
        bfloat16x8_t masked_v2 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask));

        acc1 = vbfdotq_f32(acc1, masked_v1, masked_v2);

        // Advance pointers
        vec1 += chunk_residual;
        vec2 += chunk_residual;
    }

    // Handle (residual - (residual % 8)) in chunks of 8 bfloat16
    if constexpr (residual >= 8)
        InnerProduct_Step(vec1, vec2, acc2);
    if constexpr (residual >= 16)
        InnerProduct_Step(vec1, vec2, acc3);
    if constexpr (residual >= 24)
        InnerProduct_Step(vec1, vec2, acc4);

    // Process the rest of the vectors (the full chunks part)
    while (vec1 < v1End) {
        // TODO: use `vld1q_bf16_x4` for quad-loading?
        InnerProduct_Step(vec1, vec2, acc1);
        InnerProduct_Step(vec1, vec2, acc2);
        InnerProduct_Step(vec1, vec2, acc3);
        InnerProduct_Step(vec1, vec2, acc4);
    }

    // Accumulate accumulators
    acc1 = vpaddq_f32(acc1, acc3);
    acc2 = vpaddq_f32(acc2, acc4);
    acc1 = vpaddq_f32(acc1, acc2);

    // Pairwise add to get horizontal sum
    float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
    folded = vpadd_f32(folded, folded);

    // Extract result
    return 1.0f - vget_lane_f32(folded, 0);
}
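
A small usage sketch of the kernel above, assuming a little-endian host, a +bf16 toolchain, and that the residual template parameter is chosen as dim % 32 (per the 0..31 comment). Test data is produced by truncating floats to bf16 bit patterns, which both 0.5 and 2.0 survive exactly:

// Usage sketch; assumes the BF16_InnerProduct_NEON template above is included.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Truncating float -> bf16 conversion (keep the top 16 bits). Fine for test
// data; production code would typically round to nearest-even instead.
static uint16_t to_bf16_bits(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

int main() {
    constexpr size_t dim = 20; // 20 % 32 == 20 -> residual = 20
    std::vector<uint16_t> a(dim, to_bf16_bits(0.5f));
    std::vector<uint16_t> b(dim, to_bf16_bits(2.0f));
    // Every pairwise product is 1.0, so the IP distance is 1 - 20 = -19.
    float dist = BF16_InnerProduct_NEON<20>(a.data(), b.data(), dim);
    std::printf("distance = %f\n", dist);
    return 0;
}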

src/VecSim/spaces/IP/IP_SVE_BF16.h

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include <arm_sve.h>

inline void InnerProduct_Step(const bfloat16_t *vec1, const bfloat16_t *vec2, svfloat32_t &acc,
                              size_t &offset, const size_t chunk) {
    svbool_t all = svptrue_b16();

    // Load brain-half-precision vectors.
    svbfloat16_t v1 = svld1_bf16(all, vec1 + offset);
    svbfloat16_t v2 = svld1_bf16(all, vec2 + offset);
    // Compute multiplications and add to the accumulator
    acc = svbfdot(acc, v1, v2);

    // Move to next chunk
    offset += chunk;
}

template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
float BF16_InnerProduct_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
    const size_t chunk = svcnth(); // number of 16-bit elements in a register
    svfloat32_t acc1 = svdup_f32(0.0f);
    svfloat32_t acc2 = svdup_f32(0.0f);
    svfloat32_t acc3 = svdup_f32(0.0f);
    svfloat32_t acc4 = svdup_f32(0.0f);
    size_t offset = 0;

    // Process all full vectors
    const size_t full_iterations = dimension / chunk / 4;
    for (size_t iter = 0; iter < full_iterations; iter++) {
        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc4, offset, chunk);
    }

    // Perform between 0 and 3 additional steps, according to `additional_steps` value
    if constexpr (additional_steps >= 1)
        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
    if constexpr (additional_steps >= 2)
        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
    if constexpr (additional_steps >= 3)
        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);

    // Handle the tail with the residual predicate
    if constexpr (partial_chunk) {
        svbool_t pg = svwhilelt_b16_u64(offset, dimension);

        // Load brain-half-precision vectors.
        // Inactive elements are zeros, according to the docs
        svbfloat16_t v1 = svld1_bf16(pg, vec1 + offset);
        svbfloat16_t v2 = svld1_bf16(pg, vec2 + offset);
        // Compute multiplications and add to the accumulator.
        acc4 = svbfdot(acc4, v1, v2);
    }

    // Accumulate accumulators
    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc3);
    acc2 = svadd_f32_x(svptrue_b32(), acc2, acc4);
    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc2);

    // Reduce the accumulated sum.
    float result = svaddv_f32(svptrue_b32(), acc1);
    return 1.0f - result;
}
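
The <partial_chunk, additional_steps> parameters are compile-time constants, while the SVE vector length is only known at run time, so a chooser has to branch once on values derived from svcnth(). A hedged sketch of that derivation (helper names are mine, not the project's):

// Hedged sketch; assumes BF16_InnerProduct_SVE (the template above) is visible
// and the file is built with -march=armv8.2-a+sve+bf16.
#include <arm_sve.h>
#include <cstddef>

using bf16_dist_func_t = float (*)(const void *, const void *, std::size_t);

// Hypothetical helper: pick the specialization once `partial` is known.
template <bool partial>
static bf16_dist_func_t pick_by_steps(unsigned char additional_steps) {
    switch (additional_steps) {
    case 1: return BF16_InnerProduct_SVE<partial, 1>;
    case 2: return BF16_InnerProduct_SVE<partial, 2>;
    case 3: return BF16_InnerProduct_SVE<partial, 3>;
    default: return BF16_InnerProduct_SVE<partial, 0>;
    }
}

static bf16_dist_func_t choose_bf16_ip_sve_sketch(std::size_t dimension) {
    const std::size_t chunk = svcnth();            // 16-bit lanes per SVE register
    const bool partial = (dimension % chunk) != 0; // tail handled by the predicate
    // Full chunks not covered by the unrolled-by-4 main loop: 0..3 extra steps.
    const auto steps = static_cast<unsigned char>((dimension / chunk) % 4);
    return partial ? pick_by_steps<true>(steps) : pick_by_steps<false>(steps);
}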

src/VecSim/spaces/IP_space.cpp

Lines changed: 18 additions & 2 deletions
@@ -19,7 +19,9 @@
 #include "VecSim/spaces/functions/AVX2.h"
 #include "VecSim/spaces/functions/SSE3.h"
 #include "VecSim/spaces/functions/NEON.h"
+#include "VecSim/spaces/functions/NEON_BF16.h"
 #include "VecSim/spaces/functions/SVE.h"
+#include "VecSim/spaces/functions/SVE_BF16.h"
 #include "VecSim/spaces/functions/SVE2.h"

 using bfloat16 = vecsim_types::bfloat16;
@@ -133,13 +135,27 @@ dist_func_t<float> IP_BF16_GetDistFunc(size_t dim, unsigned char *alignment, con
     if (!is_little_endian()) {
         return BF16_InnerProduct_BigEndian;
     }
+    auto features = getCpuOptimizationFeatures(arch_opt);
+
+#if defined(CPU_FEATURES_ARCH_AARCH64)
+#ifdef OPT_SVE_BF16
+    if (features.svebf16) {
+        return Choose_BF16_IP_implementation_SVE_BF16(dim);
+    }
+#endif
+#ifdef OPT_NEON_BF16
+    if (features.bf16 && dim >= 8) { // Optimization assumes at least 8 BF16s (full chunk)
+        return Choose_BF16_IP_implementation_NEON_BF16(dim);
+    }
+#endif
+#endif // AARCH64
+
+#if defined(CPU_FEATURES_ARCH_X86_64)
     // Optimizations assume at least 32 bfloats. If we have less, we use the naive implementation.
     if (dim < 32) {
         return ret_dist_func;
     }

-#ifdef CPU_FEATURES_ARCH_X86_64
-    auto features = getCpuOptimizationFeatures(arch_opt);
 #ifdef OPT_AVX512_BF16_VL
     if (features.avx512_bf16 && features.avx512vl) {
         if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
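
Whichever branch the dispatcher takes, every BF16 inner-product implementation has to agree with the same scalar definition: promote each bf16 to f32, accumulate the dot product, and return 1 - sum. A plain scalar reference, assuming little-endian storage and the usual "bf16 is the top half of an f32" layout, which the SIMD kernels in this commit can be checked against:

#include <cstddef>
#include <cstdint>
#include <cstring>

// bf16 -> f32: place the 16 bits in the upper half of an IEEE-754 binary32.
static float bf16_to_float(uint16_t h) {
    uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Scalar reference for the BF16 inner-product distance: 1 - <v1, v2>.
static float bf16_ip_distance_ref(const uint16_t *v1, const uint16_t *v2, size_t dim) {
    float sum = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        sum += bf16_to_float(v1[i]) * bf16_to_float(v2[i]);
    }
    return 1.0f - sum;
}
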
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include <arm_neon.h>

// Assumes little-endianness
inline void L2Sqr_Op(float32x4_t &acc, bfloat16x8_t &v1, bfloat16x8_t &v2) {
    float32x4_t v1_lo = vcvtq_low_f32_bf16(v1);
    float32x4_t v2_lo = vcvtq_low_f32_bf16(v2);
    float32x4_t diff_lo = vsubq_f32(v1_lo, v2_lo);

    acc = vfmaq_f32(acc, diff_lo, diff_lo);

    float32x4_t v1_hi = vcvtq_high_f32_bf16(v1);
    float32x4_t v2_hi = vcvtq_high_f32_bf16(v2);
    float32x4_t diff_hi = vsubq_f32(v1_hi, v2_hi);

    acc = vfmaq_f32(acc, diff_hi, diff_hi);
}

inline void L2Sqr_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
    // Load brain-half-precision vectors
    bfloat16x8_t v1 = vld1q_bf16(vec1);
    bfloat16x8_t v2 = vld1q_bf16(vec2);
    vec1 += 8;
    vec2 += 8;
    L2Sqr_Op(acc, v1, v2);
}

template <unsigned char residual> // 0..31
float BF16_L2Sqr_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
    const auto *const v1End = vec1 + dimension;
    float32x4_t acc1 = vdupq_n_f32(0.0f);
    float32x4_t acc2 = vdupq_n_f32(0.0f);
    float32x4_t acc3 = vdupq_n_f32(0.0f);
    float32x4_t acc4 = vdupq_n_f32(0.0f);

    // First, handle the partial chunk residual
    if constexpr (residual % 8) {
        auto constexpr chunk_residual = residual % 8;
        // TODO: special cases for some residuals and benchmark if it's better
        constexpr uint16x8_t mask = {
            0xFFFF,
            (chunk_residual >= 2) ? 0xFFFF : 0,
            (chunk_residual >= 3) ? 0xFFFF : 0,
            (chunk_residual >= 4) ? 0xFFFF : 0,
            (chunk_residual >= 5) ? 0xFFFF : 0,
            (chunk_residual >= 6) ? 0xFFFF : 0,
            (chunk_residual >= 7) ? 0xFFFF : 0,
            0,
        };

        // Load partial vectors
        bfloat16x8_t v1 = vld1q_bf16(vec1);
        bfloat16x8_t v2 = vld1q_bf16(vec2);

        // Apply mask to both vectors
        bfloat16x8_t masked_v1 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask));
        bfloat16x8_t masked_v2 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask));

        L2Sqr_Op(acc1, masked_v1, masked_v2);

        // Advance pointers
        vec1 += chunk_residual;
        vec2 += chunk_residual;
    }

    // Handle (residual - (residual % 8)) in chunks of 8 bfloat16
    if constexpr (residual >= 8)
        L2Sqr_Step(vec1, vec2, acc2);
    if constexpr (residual >= 16)
        L2Sqr_Step(vec1, vec2, acc3);
    if constexpr (residual >= 24)
        L2Sqr_Step(vec1, vec2, acc4);

    // Process the rest of the vectors (the full chunks part)
    while (vec1 < v1End) {
        // TODO: use `vld1q_bf16_x4` for quad-loading?
        L2Sqr_Step(vec1, vec2, acc1);
        L2Sqr_Step(vec1, vec2, acc2);
        L2Sqr_Step(vec1, vec2, acc3);
        L2Sqr_Step(vec1, vec2, acc4);
    }

    // Accumulate accumulators
    acc1 = vpaddq_f32(acc1, acc3);
    acc2 = vpaddq_f32(acc2, acc4);
    acc1 = vpaddq_f32(acc1, acc2);

    // Pairwise add to get horizontal sum
    float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
    folded = vpadd_f32(folded, folded);

    // Extract result
    return vget_lane_f32(folded, 0);
}
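
And the scalar behaviour this L2 kernel mirrors: vcvtq_low_f32_bf16/vcvtq_high_f32_bf16 widen bf16 to f32 exactly (the bf16 bits become the top half of the f32), then the squared differences are accumulated. A minimal reference sketch, again assuming little-endian storage of raw bf16 bit patterns:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Scalar reference for the BF16 squared-L2 distance (no `1 - x` at the end).
static float bf16_l2sqr_ref(const uint16_t *v1, const uint16_t *v2, size_t dim) {
    float sum = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        uint32_t a_bits = static_cast<uint32_t>(v1[i]) << 16; // widen bf16 -> f32
        uint32_t b_bits = static_cast<uint32_t>(v2[i]) << 16;
        float a, b;
        std::memcpy(&a, &a_bits, sizeof(a));
        std::memcpy(&b, &b_bits, sizeof(b));
        float diff = a - b;
        sum += diff * diff;
    }
    return sum;
}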
