
Commit d970056

Implement optimized FP16 support for ARM architecture - [MOD-9078] (#620)
* implement L2 SVE with intermediate casting to f32
* implement IP SVE with f16 ops only
* implement L2 SVE with no intermediate casting
* add SVE and SVE2 functions files
* add new files to cmake and use new implementations
* added benchmarks
* fix and switch implementation (due to sve2-only op)
* test with SVE2 intrinsics
* Revert "test with SVE2 intrinsics"; this reverts commit 06dd65c.
* remove redundant implementation
* move to 4 steps per iteration implementations
* add macro cleanup
* fix implementation
* refactor to use 4 accumulators
* added tests
* refactor accumulation
* add initial neon implementation
* fix build flags and file layout
* fix tests
* cleanup and L2 implementation with neon+fp16
* format
* fix test for any arch
* another attempt
* fix test
* rename step functions
* comment-in neon benchmarks
* fix benchmark
* review fixes
* more review fixes
* fixes and cleanup
* fix svwhilelt_b16 calls
* use vbslq_f16
* typo fix
* fix test for OSs that don't support fp16
* added back guards for a specific x86 test

(cherry picked from commit fcc8d78)
1 parent 38dbd58 commit d970056

File tree

18 files changed: +631 additions, -32 deletions


cmake/aarch64InstructionFlags.cmake

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@ CHECK_CXX_COMPILER_FLAG("-march=armv7-a+neon" CXX_ARMV7_NEON)
 CHECK_CXX_COMPILER_FLAG("-march=armv8-a" CXX_ARMV8A)
 CHECK_CXX_COMPILER_FLAG("-march=armv8-a+sve" CXX_SVE)
 CHECK_CXX_COMPILER_FLAG("-march=armv9-a+sve2" CXX_SVE2)
+CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+fp16fml" CXX_NEON_HP)
 CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+bf16" CXX_NEON_BF16)
 CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+sve+bf16" CXX_SVE_BF16)

@@ -17,7 +18,12 @@ if(CXX_SVE2)
     add_compile_definitions(OPT_SVE2)
 endif()
 if (CXX_ARMV8A OR CXX_ARMV7_NEON)
+    message(STATUS "Using ARMv8.0-a with NEON")
     add_compile_definitions(OPT_NEON)
+endif()
+if (CXX_NEON_HP)
+    message(STATUS "Using ARMv8.2-a with NEON half-precision extension")
+    add_compile_definitions(OPT_NEON_HP)
 endif()
 if (CXX_NEON_BF16)
     add_compile_definitions(OPT_NEON_BF16)
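For context, the new CXX_NEON_HP check only verifies that the compiler accepts -march=armv8.2-a+fp16fml. A minimal standalone probe like the following (not part of the commit; the file name and function are illustrative) can be compiled with that flag to confirm the half-precision arithmetic intrinsics the new kernels rely on are actually available:

// fp16_probe.cpp (hypothetical): build with `g++ -march=armv8.2-a+fp16fml -c fp16_probe.cpp`
#include <arm_neon.h>

// vfmaq_f16 requires the ARMv8.2-a half-precision arithmetic extension,
// which the +fp16fml feature implies.
float16x8_t fp16_probe(float16x8_t acc, float16x8_t a, float16x8_t b) {
    return vfmaq_f16(acc, a, b);
}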

src/VecSim/spaces/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -85,6 +85,13 @@ if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)|(ARM64)|(armv.*)")
     list(APPEND OPTIMIZATIONS functions/NEON.cpp)
 endif()

+# NEON half-precision support
+if (CXX_NEON_HP AND CXX_ARMV8A)
+    message("Building with NEON+HP")
+    set_source_files_properties(functions/NEON_HP.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16fml")
+    list(APPEND OPTIMIZATIONS functions/NEON_HP.cpp)
+endif()
+
 # NEON bfloat16 support
 if (CXX_NEON_BF16)
     message("Building with NEON + BF16")
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include <arm_neon.h>

inline void InnerProduct_Step(const float16_t *&vec1, const float16_t *&vec2, float16x8_t &acc) {
    // Load half-precision vectors
    float16x8_t v1 = vld1q_f16(vec1);
    float16x8_t v2 = vld1q_f16(vec2);
    vec1 += 8;
    vec2 += 8;

    // Multiply and accumulate
    acc = vfmaq_f16(acc, v1, v2);
}

template <unsigned char residual> // 0..31
float FP16_InnerProduct_NEON_HP(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
    const auto *const v1End = vec1 + dimension;
    float16x8_t acc1 = vdupq_n_f16(0.0f);
    float16x8_t acc2 = vdupq_n_f16(0.0f);
    float16x8_t acc3 = vdupq_n_f16(0.0f);
    float16x8_t acc4 = vdupq_n_f16(0.0f);

    // First, handle the partial chunk residual
    if constexpr (residual % 8) {
        auto constexpr chunk_residual = residual % 8;
        // TODO: special cases for some residuals and benchmark if it's better
        constexpr uint16x8_t mask = {
            0xFFFF,
            (chunk_residual >= 2) ? 0xFFFF : 0,
            (chunk_residual >= 3) ? 0xFFFF : 0,
            (chunk_residual >= 4) ? 0xFFFF : 0,
            (chunk_residual >= 5) ? 0xFFFF : 0,
            (chunk_residual >= 6) ? 0xFFFF : 0,
            (chunk_residual >= 7) ? 0xFFFF : 0,
            0,
        };

        // Load partial vectors
        float16x8_t v1 = vld1q_f16(vec1);
        float16x8_t v2 = vld1q_f16(vec2);

        // Apply mask to both vectors
        float16x8_t masked_v1 = vbslq_f16(mask, v1, acc1); // `acc1` should be all zeros here
        float16x8_t masked_v2 = vbslq_f16(mask, v2, acc2); // `acc2` should be all zeros here

        // Multiply and accumulate
        acc1 = vfmaq_f16(acc1, masked_v1, masked_v2);

        // Advance pointers
        vec1 += chunk_residual;
        vec2 += chunk_residual;
    }

    // Handle (residual - (residual % 8)) in chunks of 8 float16
    if constexpr (residual >= 8)
        InnerProduct_Step(vec1, vec2, acc2);
    if constexpr (residual >= 16)
        InnerProduct_Step(vec1, vec2, acc3);
    if constexpr (residual >= 24)
        InnerProduct_Step(vec1, vec2, acc4);

    // Process the rest of the vectors (the full chunks part)
    while (vec1 < v1End) {
        // TODO: use `vld1q_f16_x4` for quad-loading?
        InnerProduct_Step(vec1, vec2, acc1);
        InnerProduct_Step(vec1, vec2, acc2);
        InnerProduct_Step(vec1, vec2, acc3);
        InnerProduct_Step(vec1, vec2, acc4);
    }

    // Accumulate accumulators
    acc1 = vpaddq_f16(acc1, acc3);
    acc2 = vpaddq_f16(acc2, acc4);
    acc1 = vpaddq_f16(acc1, acc2);

    // Horizontal sum of the accumulated values
    float32x4_t sum_f32 = vcvt_f32_f16(vget_low_f16(acc1));
    sum_f32 = vaddq_f32(sum_f32, vcvt_f32_f16(vget_high_f16(acc1)));

    // Pairwise add to get horizontal sum
    float32x2_t sum_2 = vadd_f32(vget_low_f32(sum_f32), vget_high_f32(sum_f32));
    sum_2 = vpadd_f32(sum_2, sum_2);

    // Extract result
    return 1.0f - vget_lane_f32(sum_2, 0);
}
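The kernel above accumulates in fp16 across four registers and only converts to fp32 for the final horizontal sum. A minimal scalar reference that it should agree with, up to fp16 rounding in the accumulators, could look like the sketch below (an illustration, not part of the commit; the name FP16_InnerProduct_Ref is made up):

#include <arm_neon.h> // for float16_t
#include <cstddef>

// Scalar reference: promote each fp16 element to float and accumulate in fp32.
float FP16_InnerProduct_Ref(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
    float sum = 0.0f;
    for (size_t i = 0; i < dimension; i++) {
        sum += static_cast<float>(vec1[i]) * static_cast<float>(vec2[i]);
    }
    // Same convention as the SIMD kernel: the IP distance is 1 - dot product.
    return 1.0f - sum;
}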

src/VecSim/spaces/IP/IP_SVE_FP16.h

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include <arm_sve.h>

inline void InnerProduct_Step(const float16_t *vec1, const float16_t *vec2, svfloat16_t &acc,
                              size_t &offset, const size_t chunk) {
    svbool_t all = svptrue_b16();

    // Load half-precision vectors.
    svfloat16_t v1 = svld1_f16(all, vec1 + offset);
    svfloat16_t v2 = svld1_f16(all, vec2 + offset);
    // Compute multiplications and add to the accumulator
    acc = svmla_f16_x(all, acc, v1, v2);

    // Move to next chunk
    offset += chunk;
}

template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
float FP16_InnerProduct_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
    const size_t chunk = svcnth(); // number of 16-bit elements in a register
    svbool_t all = svptrue_b16();
    svfloat16_t acc1 = svdup_f16(0.0f);
    svfloat16_t acc2 = svdup_f16(0.0f);
    svfloat16_t acc3 = svdup_f16(0.0f);
    svfloat16_t acc4 = svdup_f16(0.0f);
    size_t offset = 0;

    // Process all full vectors
    const size_t full_iterations = dimension / chunk / 4;
    for (size_t iter = 0; iter < full_iterations; iter++) {
        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc4, offset, chunk);
    }

    // Perform between 0 and 3 additional steps, according to `additional_steps` value
    if constexpr (additional_steps >= 1)
        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
    if constexpr (additional_steps >= 2)
        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
    if constexpr (additional_steps >= 3)
        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);

    // Handle the tail with the residual predicate
    if constexpr (partial_chunk) {
        svbool_t pg = svwhilelt_b16_u64(offset, dimension);

        // Load half-precision vectors.
        svfloat16_t v1 = svld1_f16(pg, vec1 + offset);
        svfloat16_t v2 = svld1_f16(pg, vec2 + offset);
        // Compute multiplications and add to the accumulator.
        // Use the existing value of `acc4` for the inactive elements (by the `m` suffix)
        acc4 = svmla_f16_m(pg, acc4, v1, v2);
    }

    // Accumulate accumulators
    acc1 = svadd_f16_x(all, acc1, acc3);
    acc2 = svadd_f16_x(all, acc2, acc4);
    acc1 = svadd_f16_x(all, acc1, acc2);

    // Reduce the accumulated sum.
    float result = svaddv_f16(all, acc1);
    return 1.0f - result;
}
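The two template parameters encode how the dimension decomposes at the SVE register width: full groups of four chunks go through the main loop, additional_steps covers the 0..3 leftover full chunks, and partial_chunk enables the predicated tail. A sketch of a runtime dispatcher built on that decomposition follows (an illustration under those assumptions; the commit's actual chooser may be structured differently):

#include <arm_sve.h>
#include <cstddef>
#include "VecSim/spaces/IP/IP_SVE_FP16.h" // the kernel shown above

template <bool partial>
static float ip_fp16_sve_dispatch_extra(const void *v1, const void *v2, size_t dim, size_t extra) {
    // Map the 0..3 leftover full chunks onto the template parameter.
    switch (extra) {
    case 1: return FP16_InnerProduct_SVE<partial, 1>(v1, v2, dim);
    case 2: return FP16_InnerProduct_SVE<partial, 2>(v1, v2, dim);
    case 3: return FP16_InnerProduct_SVE<partial, 3>(v1, v2, dim);
    default: return FP16_InnerProduct_SVE<partial, 0>(v1, v2, dim);
    }
}

// Illustrative chooser: derive (partial_chunk, additional_steps) from the dimension.
float ip_fp16_sve_dispatch_sketch(const void *v1, const void *v2, size_t dim) {
    const size_t chunk = svcnth();           // fp16 lanes per SVE register
    const bool partial = (dim % chunk) != 0; // tail handled by the predicate
    const size_t extra = (dim / chunk) % 4;  // full chunks not covered by the 4-step loop
    return partial ? ip_fp16_sve_dispatch_extra<true>(v1, v2, dim, extra)
                   : ip_fp16_sve_dispatch_extra<false>(v1, v2, dim, extra);
}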

src/VecSim/spaces/IP_space.cpp

Lines changed: 22 additions & 2 deletions
@@ -19,6 +19,7 @@
 #include "VecSim/spaces/functions/AVX2.h"
 #include "VecSim/spaces/functions/SSE3.h"
 #include "VecSim/spaces/functions/NEON.h"
+#include "VecSim/spaces/functions/NEON_HP.h"
 #include "VecSim/spaces/functions/NEON_BF16.h"
 #include "VecSim/spaces/functions/SVE.h"
 #include "VecSim/spaces/functions/SVE_BF16.h"

@@ -213,14 +214,33 @@ dist_func_t<float> IP_FP16_GetDistFunc(size_t dim, unsigned char *alignment, con
     if (alignment == nullptr) {
         alignment = &dummy_alignment;
     }
+    auto features = getCpuOptimizationFeatures(arch_opt);

     dist_func_t<float> ret_dist_func = FP16_InnerProduct;
+
+#if defined(CPU_FEATURES_ARCH_AARCH64)
+#ifdef OPT_SVE2
+    if (features.sve2) {
+        return Choose_FP16_IP_implementation_SVE2(dim);
+    }
+#endif
+#ifdef OPT_SVE
+    if (features.sve) {
+        return Choose_FP16_IP_implementation_SVE(dim);
+    }
+#endif
+#ifdef OPT_NEON_HP
+    if (features.asimdhp && dim >= 8) { // Optimization assumes at least 8 16FPs (full chunk)
+        return Choose_FP16_IP_implementation_NEON_HP(dim);
+    }
+#endif
+#endif
+
+#if defined(CPU_FEATURES_ARCH_X86_64)
     // Optimizations assume at least 32 16FPs. If we have less, we use the naive implementation.
     if (dim < 32) {
         return ret_dist_func;
     }
-#ifdef CPU_FEATURES_ARCH_X86_64
-    auto features = getCpuOptimizationFeatures(arch_opt);
 #ifdef OPT_AVX512_FP16_VL
     // More details about the dimension limitation can be found in this PR's description:
     // https://github.com/RedisAI/VectorSimilarity/pull/477
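For callers, the net effect of this hunk is that IP_FP16_GetDistFunc can now hand back an SVE2, SVE, or NEON fp16 kernel on aarch64 before the x86 path is considered. A hedged usage sketch follows; the dist_func_t alias and the type of the truncated third parameter are assumptions here, as is the meaning of passing nullptr for it:

#include <cstddef>

// Assumed shape of the repo's types, redeclared locally for this sketch only.
template <typename RET> using dist_func_t = RET (*)(const void *, const void *, size_t);
dist_func_t<float> IP_FP16_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt);

float fp16_ip_distance(const void *a, const void *b, size_t dim) {
    unsigned char alignment = 0;
    // Passing nullptr for arch_opt is assumed to mean "use the CPU features detected at runtime".
    dist_func_t<float> dist = IP_FP16_GetDistFunc(dim, &alignment, nullptr);
    return dist(a, b, dim); // for IP spaces this returns 1 - <a, b>
}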
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include <arm_neon.h>

inline void L2Sqr_Step(const float16_t *&vec1, const float16_t *&vec2, float16x8_t &acc) {
    // Load half-precision vectors
    float16x8_t v1 = vld1q_f16(vec1);
    float16x8_t v2 = vld1q_f16(vec2);
    vec1 += 8;
    vec2 += 8;

    // Calculate differences
    float16x8_t diff = vsubq_f16(v1, v2);
    // Square and accumulate
    acc = vfmaq_f16(acc, diff, diff);
}

template <unsigned char residual> // 0..31
float FP16_L2Sqr_NEON_HP(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
    const auto *const v1End = vec1 + dimension;
    float16x8_t acc1 = vdupq_n_f16(0.0f);
    float16x8_t acc2 = vdupq_n_f16(0.0f);
    float16x8_t acc3 = vdupq_n_f16(0.0f);
    float16x8_t acc4 = vdupq_n_f16(0.0f);

    // First, handle the partial chunk residual
    if constexpr (residual % 8) {
        auto constexpr chunk_residual = residual % 8;
        // TODO: special cases for some residuals and benchmark if it's better
        constexpr uint16x8_t mask = {
            0xFFFF,
            (chunk_residual >= 2) ? 0xFFFF : 0,
            (chunk_residual >= 3) ? 0xFFFF : 0,
            (chunk_residual >= 4) ? 0xFFFF : 0,
            (chunk_residual >= 5) ? 0xFFFF : 0,
            (chunk_residual >= 6) ? 0xFFFF : 0,
            (chunk_residual >= 7) ? 0xFFFF : 0,
            0,
        };

        // Load partial vectors
        float16x8_t v1 = vld1q_f16(vec1);
        float16x8_t v2 = vld1q_f16(vec2);

        // Apply mask to both vectors
        float16x8_t masked_v1 = vbslq_f16(mask, v1, acc1); // `acc1` should be all zeros here
        float16x8_t masked_v2 = vbslq_f16(mask, v2, acc2); // `acc2` should be all zeros here

        // Calculate differences
        float16x8_t diff = vsubq_f16(masked_v1, masked_v2);
        // Square and accumulate
        acc1 = vfmaq_f16(acc1, diff, diff);

        // Advance pointers
        vec1 += chunk_residual;
        vec2 += chunk_residual;
    }

    // Handle (residual - (residual % 8)) in chunks of 8 float16
    if constexpr (residual >= 8)
        L2Sqr_Step(vec1, vec2, acc2);
    if constexpr (residual >= 16)
        L2Sqr_Step(vec1, vec2, acc3);
    if constexpr (residual >= 24)
        L2Sqr_Step(vec1, vec2, acc4);

    // Process the rest of the vectors (the full chunks part)
    while (vec1 < v1End) {
        // TODO: use `vld1q_f16_x4` for quad-loading?
        L2Sqr_Step(vec1, vec2, acc1);
        L2Sqr_Step(vec1, vec2, acc2);
        L2Sqr_Step(vec1, vec2, acc3);
        L2Sqr_Step(vec1, vec2, acc4);
    }

    // Accumulate accumulators
    acc1 = vpaddq_f16(acc1, acc3);
    acc2 = vpaddq_f16(acc2, acc4);
    acc1 = vpaddq_f16(acc1, acc2);

    // Horizontal sum of the accumulated values
    float32x4_t sum_f32 = vcvt_f32_f16(vget_low_f16(acc1));
    sum_f32 = vaddq_f32(sum_f32, vcvt_f32_f16(vget_high_f16(acc1)));

    // Pairwise add to get horizontal sum
    float32x2_t sum_2 = vadd_f32(vget_low_f32(sum_f32), vget_high_f32(sum_f32));
    sum_2 = vpadd_f32(sum_2, sum_2);

    // Extract result
    return vget_lane_f32(sum_2, 0);
}
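As with the inner-product kernel, a scalar reference makes the intended result explicit: the squared L2 distance with no square root, differing from the SIMD version only by fp16 accumulation error. The sketch below is illustrative only and not part of the commit; the name FP16_L2Sqr_Ref is made up:

#include <arm_neon.h> // for float16_t
#include <cstddef>

float FP16_L2Sqr_Ref(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
    float sum = 0.0f;
    for (size_t i = 0; i < dimension; i++) {
        float diff = static_cast<float>(vec1[i]) - static_cast<float>(vec2[i]);
        sum += diff * diff; // accumulate squared differences in fp32
    }
    return sum; // squared L2, no sqrt (matches the kernel above)
}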
