
Commit ec0b2b4

Merge remote-tracking branch 'origin/main' into Omer_arm_int8_sve_sve2
2 parents: 69efdaa + 1e08ea4

36 files changed: +1271 −46 lines

cmake/aarch64InstructionFlags.cmake

Lines changed: 14 additions & 0 deletions
@@ -8,15 +8,29 @@ CHECK_CXX_COMPILER_FLAG("-march=armv7-a+neon" CXX_ARMV7_NEON)
 CHECK_CXX_COMPILER_FLAG("-march=armv8-a" CXX_ARMV8A)
 CHECK_CXX_COMPILER_FLAG("-march=armv8-a+sve" CXX_SVE)
 CHECK_CXX_COMPILER_FLAG("-march=armv9-a+sve2" CXX_SVE2)
+CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+fp16fml" CXX_NEON_HP)
+CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+bf16" CXX_NEON_BF16)
+CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+sve+bf16" CXX_SVE_BF16)
 
 # Only use ARMv9 if both compiler and CPU support it
 if(CXX_SVE2)
     message(STATUS "Using ARMv9.0-a with SVE2 (supported by CPU)")
     add_compile_definitions(OPT_SVE2)
 endif()
 if (CXX_ARMV8A OR CXX_ARMV7_NEON)
+    message(STATUS "Using ARMv8.0-a with NEON")
     add_compile_definitions(OPT_NEON)
+endif()
+if (CXX_NEON_HP)
+    message(STATUS "Using ARMv8.2-a with NEON half-precision extension")
+    add_compile_definitions(OPT_NEON_HP)
+endif()
+if (CXX_NEON_BF16)
+    add_compile_definitions(OPT_NEON_BF16)
 endif()
 if (CXX_SVE)
     add_compile_definitions(OPT_SVE)
 endif()
+if (CXX_SVE_BF16)
+    add_compile_definitions(OPT_SVE_BF16)
+endif()
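
Note: the OPT_* compile definitions set above are what the C++ sources key on. A minimal sketch of how such a definition typically gates an optimized kernel; the dispatcher and fallback names here are hypothetical, not this repo's actual API:

#include <cstddef>

// Hypothetical NEON implementation, defined in a file compiled with -march=armv8-a
float MyNeonKernel(const float *a, const float *b, size_t dim);

// Hypothetical portable fallback
float MyScalarFallback(const float *a, const float *b, size_t dim) {
    float sum = 0.0f;
    for (size_t i = 0; i < dim; i++)
        sum += a[i] * b[i];
    return 1.0f - sum;
}

float InnerProductDispatch(const float *a, const float *b, size_t dim) {
#ifdef OPT_NEON
    return MyNeonKernel(a, b, dim); // only compiled when the flag check passed
#else
    return MyScalarFallback(a, b, dim);
#endif
}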

src/VecSim/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE}
     index_factories/tiered_factory.cpp
     ${svs_factory_file}
     index_factories/index_factory.cpp
+    index_factories/components/components_factory.cpp
     algorithms/hnsw/visited_nodes_handler.cpp
     vec_sim.cpp
     vec_sim_debug.cpp

src/VecSim/index_factories/components/components_factory.cpp

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+/*
+ *Copyright Redis Ltd. 2021 - present
+ *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
+ *the Server Side Public License v1 (SSPLv1).
+ */
+
+#include "VecSim/index_factories/components/components_factory.h"
+
+PreprocessorsContainerParams CreatePreprocessorsContainerParams(VecSimMetric metric, size_t dim,
+                                                                bool is_normalized,
+                                                                unsigned char alignment) {
+    // If the index metric is Cosine, and is_normalized == true, we will skip normalizing vectors
+    // and query blobs.
+    VecSimMetric pp_metric;
+    if (is_normalized && metric == VecSimMetric_Cosine) {
+        pp_metric = VecSimMetric_IP;
+    } else {
+        pp_metric = metric;
+    }
+    return {.metric = pp_metric, .dim = dim, .alignment = alignment};
+}
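
Note: a usage sketch with illustrative values (not from this commit). For a Cosine index whose vectors are declared already normalized, the params downgrade the metric to plain inner product, so the preprocessors container skips normalization:

PreprocessorsContainerParams params =
    CreatePreprocessorsContainerParams(VecSimMetric_Cosine, /*dim=*/128,
                                       /*is_normalized=*/true, /*alignment=*/32);
// params.metric == VecSimMetric_IP; params.dim == 128; params.alignment == 32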

src/VecSim/index_factories/components/components_factory.h

Lines changed: 6 additions & 10 deletions
@@ -12,6 +12,10 @@
 #include "VecSim/index_factories/components/preprocessors_factory.h"
 #include "VecSim/spaces/computer/calculator.h"
 
+PreprocessorsContainerParams CreatePreprocessorsContainerParams(VecSimMetric metric, size_t dim,
+                                                                bool is_normalized,
+                                                                unsigned char alignment);
+
 template <typename DataType, typename DistType>
 IndexComponents<DataType, DistType>
 CreateIndexComponents(std::shared_ptr<VecSimAllocator> allocator, VecSimMetric metric, size_t dim,
@@ -22,16 +26,8 @@ CreateIndexComponents(std::shared_ptr<VecSimAllocator> allocator, VecSimMetric m
     // Currently we have only one distance calculator implementation
     auto indexCalculator = new (allocator) DistanceCalculatorCommon<DistType>(allocator, distFunc);
 
-    // If the index metric is Cosine, and is_normalized == true, we will skip normalizing vectors
-    // and query blobs.
-    VecSimMetric pp_metric;
-    if (is_normalized && metric == VecSimMetric_Cosine) {
-        pp_metric = VecSimMetric_IP;
-    } else {
-        pp_metric = metric;
-    }
-    PreprocessorsContainerParams ppParams = {
-        .metric = pp_metric, .dim = dim, .alignment = alignment};
+    PreprocessorsContainerParams ppParams =
+        CreatePreprocessorsContainerParams(metric, dim, is_normalized, alignment);
     auto preprocessors = CreatePreprocessorsContainer<DataType>(allocator, ppParams);
 
     return {indexCalculator, preprocessors};
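
Note: a plausible side benefit of this extraction (an assumption; the commit does not state its motivation) is that the metric-selection logic can now be unit-tested without instantiating the CreateIndexComponents template. A GTest-style sketch with a hypothetical test name:

#include <gtest/gtest.h>
#include "VecSim/index_factories/components/components_factory.h"

TEST(ComponentsFactoryTest, NormalizedCosineFallsBackToIP) {
    auto params = CreatePreprocessorsContainerParams(VecSimMetric_Cosine, /*dim=*/4,
                                                     /*is_normalized=*/true, /*alignment=*/0);
    EXPECT_EQ(params.metric, VecSimMetric_IP);
}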

src/VecSim/spaces/CMakeLists.txt

Lines changed: 21 additions & 0 deletions
@@ -91,13 +91,34 @@ if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)|(ARM64)|(armv.*)")
     list(APPEND OPTIMIZATIONS functions/NEON.cpp)
 endif()
 
+# NEON half-precision support
+if (CXX_NEON_HP AND CXX_ARMV8A)
+    message("Building with NEON+HP")
+    set_source_files_properties(functions/NEON_HP.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16fml")
+    list(APPEND OPTIMIZATIONS functions/NEON_HP.cpp)
+endif()
+
+# NEON bfloat16 support
+if (CXX_NEON_BF16)
+    message("Building with NEON + BF16")
+    set_source_files_properties(functions/NEON_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+bf16")
+    list(APPEND OPTIMIZATIONS functions/NEON_BF16.cpp)
+endif()
+
 # SVE support
 if (CXX_SVE)
     message("Building with SVE")
     set_source_files_properties(functions/SVE.cpp PROPERTIES COMPILE_FLAGS "-march=armv8-a+sve")
     list(APPEND OPTIMIZATIONS functions/SVE.cpp)
 endif()
 
+# SVE with BF16 support
+if (CXX_SVE_BF16)
+    message("Building with SVE + BF16")
+    set_source_files_properties(functions/SVE_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+sve+bf16")
+    list(APPEND OPTIMIZATIONS functions/SVE_BF16.cpp)
+endif()
+
 # SVE2 support
 if (CXX_SVE2)
     message("Building with ARMV9A and SVE2")
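
Note: because only these individual source files receive the extended -march flags, the rest of the library still targets the baseline architecture, and calls into these kernels must be guarded by a runtime CPU-capability check. A Linux-only sketch (an assumption; this repo's actual detection code may differ):

#include <sys/auxv.h>
#if defined(__aarch64__)
#include <asm/hwcap.h>
#endif

// Returns true only when the running CPU exposes the BF16 extension,
// so it is safe to call into code built with -march=armv8.2-a+bf16.
bool CpuHasBF16() {
#if defined(__aarch64__)
    return (getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0;
#else
    return false;
#endif
}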

src/VecSim/spaces/IP/IP_NEON_BF16.h

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+/*
+ *Copyright Redis Ltd. 2021 - present
+ *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
+ *the Server Side Public License v1 (SSPLv1).
+ */
+
+#include <arm_neon.h>
+
+inline void InnerProduct_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
+    // Load brain-half-precision vectors
+    bfloat16x8_t v1 = vld1q_bf16(vec1);
+    bfloat16x8_t v2 = vld1q_bf16(vec2);
+    vec1 += 8;
+    vec2 += 8;
+    // Compute multiplications and add to the accumulator
+    acc = vbfdotq_f32(acc, v1, v2);
+}
+
+template <unsigned char residual> // 0..31
+float BF16_InnerProduct_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
+    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
+    const auto *const v1End = vec1 + dimension;
+    float32x4_t acc1 = vdupq_n_f32(0.0f);
+    float32x4_t acc2 = vdupq_n_f32(0.0f);
+    float32x4_t acc3 = vdupq_n_f32(0.0f);
+    float32x4_t acc4 = vdupq_n_f32(0.0f);
+
+    // First, handle the partial chunk residual
+    if constexpr (residual % 8) {
+        auto constexpr chunk_residual = residual % 8;
+        // TODO: special cases for some residuals, and benchmark whether it's better
+        constexpr uint16x8_t mask = {
+            0xFFFF,
+            (chunk_residual >= 2) ? 0xFFFF : 0,
+            (chunk_residual >= 3) ? 0xFFFF : 0,
+            (chunk_residual >= 4) ? 0xFFFF : 0,
+            (chunk_residual >= 5) ? 0xFFFF : 0,
+            (chunk_residual >= 6) ? 0xFFFF : 0,
+            (chunk_residual >= 7) ? 0xFFFF : 0,
+            0,
+        };
+
+        // Load partial vectors
+        bfloat16x8_t v1 = vld1q_bf16(vec1);
+        bfloat16x8_t v2 = vld1q_bf16(vec2);
+
+        // Apply mask to both vectors
+        bfloat16x8_t masked_v1 =
+            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask));
+        bfloat16x8_t masked_v2 =
+            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask));
+
+        acc1 = vbfdotq_f32(acc1, masked_v1, masked_v2);
+
+        // Advance pointers
+        vec1 += chunk_residual;
+        vec2 += chunk_residual;
+    }
+
+    // Handle (residual - (residual % 8)) in chunks of 8 bfloat16
+    if constexpr (residual >= 8)
+        InnerProduct_Step(vec1, vec2, acc2);
+    if constexpr (residual >= 16)
+        InnerProduct_Step(vec1, vec2, acc3);
+    if constexpr (residual >= 24)
+        InnerProduct_Step(vec1, vec2, acc4);
+
+    // Process the rest of the vectors (the full chunks part)
+    while (vec1 < v1End) {
+        // TODO: use `vld1q_bf16_x4` for quad-loading?
+        InnerProduct_Step(vec1, vec2, acc1);
+        InnerProduct_Step(vec1, vec2, acc2);
+        InnerProduct_Step(vec1, vec2, acc3);
+        InnerProduct_Step(vec1, vec2, acc4);
+    }
+
+    // Accumulate accumulators
+    acc1 = vpaddq_f32(acc1, acc3);
+    acc2 = vpaddq_f32(acc2, acc4);
+    acc1 = vpaddq_f32(acc1, acc2);
+
+    // Pairwise add to get horizontal sum
+    float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
+    folded = vpadd_f32(folded, folded);
+
+    // Extract result
+    return 1.0f - vget_lane_f32(folded, 0);
+}
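
Note: the residual template parameter (0..31) encodes dimension % 32, so all tail handling is resolved at compile time. A hedged sketch of one way to bind it at runtime; the repo's actual chooser mechanism may differ:

#include <array>
#include <cstddef>
#include <utility>

using IPFunc = float (*)(const void *, const void *, size_t);

// Build one instantiation per possible residual value.
template <std::size_t... Is>
constexpr std::array<IPFunc, 32> MakeResidualTable(std::index_sequence<Is...>) {
    return {BF16_InnerProduct_NEON<Is>...};
}

constexpr auto kResidualTable = MakeResidualTable(std::make_index_sequence<32>{});

// Usage: float dist = kResidualTable[dim % 32](v1, v2, dim);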

src/VecSim/spaces/IP/IP_NEON_HP.h

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+/*
+ *Copyright Redis Ltd. 2021 - present
+ *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
+ *the Server Side Public License v1 (SSPLv1).
+ */
+
+#include <arm_neon.h>
+
+inline void InnerProduct_Step(const float16_t *&vec1, const float16_t *&vec2, float16x8_t &acc) {
+    // Load half-precision vectors
+    float16x8_t v1 = vld1q_f16(vec1);
+    float16x8_t v2 = vld1q_f16(vec2);
+    vec1 += 8;
+    vec2 += 8;
+
+    // Multiply and accumulate
+    acc = vfmaq_f16(acc, v1, v2);
+}
+
+template <unsigned char residual> // 0..31
+float FP16_InnerProduct_NEON_HP(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
+    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
+    const auto *const v1End = vec1 + dimension;
+    float16x8_t acc1 = vdupq_n_f16(0.0f);
+    float16x8_t acc2 = vdupq_n_f16(0.0f);
+    float16x8_t acc3 = vdupq_n_f16(0.0f);
+    float16x8_t acc4 = vdupq_n_f16(0.0f);
+
+    // First, handle the partial chunk residual
+    if constexpr (residual % 8) {
+        auto constexpr chunk_residual = residual % 8;
+        // TODO: special cases for some residuals, and benchmark whether it's better
+        constexpr uint16x8_t mask = {
+            0xFFFF,
+            (chunk_residual >= 2) ? 0xFFFF : 0,
+            (chunk_residual >= 3) ? 0xFFFF : 0,
+            (chunk_residual >= 4) ? 0xFFFF : 0,
+            (chunk_residual >= 5) ? 0xFFFF : 0,
+            (chunk_residual >= 6) ? 0xFFFF : 0,
+            (chunk_residual >= 7) ? 0xFFFF : 0,
+            0,
+        };
+
+        // Load partial vectors
+        float16x8_t v1 = vld1q_f16(vec1);
+        float16x8_t v2 = vld1q_f16(vec2);
+
+        // Apply mask to both vectors
+        float16x8_t masked_v1 = vbslq_f16(mask, v1, acc1); // `acc1` should be all zeros here
+        float16x8_t masked_v2 = vbslq_f16(mask, v2, acc2); // `acc2` should be all zeros here
+
+        // Multiply and accumulate
+        acc1 = vfmaq_f16(acc1, masked_v1, masked_v2);
+
+        // Advance pointers
+        vec1 += chunk_residual;
+        vec2 += chunk_residual;
+    }
+
+    // Handle (residual - (residual % 8)) in chunks of 8 float16
+    if constexpr (residual >= 8)
+        InnerProduct_Step(vec1, vec2, acc2);
+    if constexpr (residual >= 16)
+        InnerProduct_Step(vec1, vec2, acc3);
+    if constexpr (residual >= 24)
+        InnerProduct_Step(vec1, vec2, acc4);
+
+    // Process the rest of the vectors (the full chunks part)
+    while (vec1 < v1End) {
+        // TODO: use `vld1q_f16_x4` for quad-loading?
+        InnerProduct_Step(vec1, vec2, acc1);
+        InnerProduct_Step(vec1, vec2, acc2);
+        InnerProduct_Step(vec1, vec2, acc3);
+        InnerProduct_Step(vec1, vec2, acc4);
+    }
+
+    // Accumulate accumulators
+    acc1 = vpaddq_f16(acc1, acc3);
+    acc2 = vpaddq_f16(acc2, acc4);
+    acc1 = vpaddq_f16(acc1, acc2);
+
+    // Horizontal sum of the accumulated values
+    float32x4_t sum_f32 = vcvt_f32_f16(vget_low_f16(acc1));
+    sum_f32 = vaddq_f32(sum_f32, vcvt_f32_f16(vget_high_f16(acc1)));
+
+    // Pairwise add to get horizontal sum
+    float32x2_t sum_2 = vadd_f32(vget_low_f32(sum_f32), vget_high_f32(sum_f32));
+    sum_2 = vpadd_f32(sum_2, sum_2);
+
+    // Extract result
+    return 1.0f - vget_lane_f32(sum_2, 0);
+}
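
Note: a self-contained usage sketch (illustrative, not from this commit). The residual path loads full 8-lane registers, so the buffers here are padded beyond `dimension`:

#include <arm_neon.h>
#include <cstddef>
// assumes FP16_InnerProduct_NEON_HP from the header above is visible

int main() {
    constexpr size_t dim = 50;        // 50 % 32 == 18, so residual = 18
    float16_t a[64] = {}, b[64] = {}; // padded past dim: masked loads still read 8 lanes
    for (size_t i = 0; i < dim; i++) {
        a[i] = 1.0f;
        b[i] = 0.5f;
    }
    float dist = FP16_InnerProduct_NEON_HP<18>(a, b, dim);
    // Inner product is 50 * 0.5 = 25, so the returned distance is 1 - 25 = -24
    // (exact here, since every partial sum is representable in fp16)
    return dist == -24.0f ? 0 : 1;
}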

src/VecSim/spaces/IP/IP_SVE_BF16.h

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+/*
+ *Copyright Redis Ltd. 2021 - present
+ *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
+ *the Server Side Public License v1 (SSPLv1).
+ */
+
+#include <arm_sve.h>
+
+inline void InnerProduct_Step(const bfloat16_t *vec1, const bfloat16_t *vec2, svfloat32_t &acc,
+                              size_t &offset, const size_t chunk) {
+    svbool_t all = svptrue_b16();
+
+    // Load brain-half-precision vectors.
+    svbfloat16_t v1 = svld1_bf16(all, vec1 + offset);
+    svbfloat16_t v2 = svld1_bf16(all, vec2 + offset);
+    // Compute multiplications and add to the accumulator
+    acc = svbfdot(acc, v1, v2);
+
+    // Move to next chunk
+    offset += chunk;
+}
+
+template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
+float BF16_InnerProduct_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
+    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
+    const size_t chunk = svcnth(); // number of 16-bit elements in a register
+    svfloat32_t acc1 = svdup_f32(0.0f);
+    svfloat32_t acc2 = svdup_f32(0.0f);
+    svfloat32_t acc3 = svdup_f32(0.0f);
+    svfloat32_t acc4 = svdup_f32(0.0f);
+    size_t offset = 0;
+
+    // Process all full vectors
+    const size_t full_iterations = dimension / chunk / 4;
+    for (size_t iter = 0; iter < full_iterations; iter++) {
+        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
+        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
+        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
+        InnerProduct_Step(vec1, vec2, acc4, offset, chunk);
+    }
+
+    // Perform between 0 and 3 additional steps, according to `additional_steps` value
+    if constexpr (additional_steps >= 1)
+        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
+    if constexpr (additional_steps >= 2)
+        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
+    if constexpr (additional_steps >= 3)
+        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
+
+    // Handle the tail with the residual predicate
+    if constexpr (partial_chunk) {
+        svbool_t pg = svwhilelt_b16_u64(offset, dimension);
+
+        // Load brain-half-precision vectors.
+        // Inactive elements are zeros, according to the docs
+        svbfloat16_t v1 = svld1_bf16(pg, vec1 + offset);
+        svbfloat16_t v2 = svld1_bf16(pg, vec2 + offset);
+        // Compute multiplications and add to the accumulator.
+        acc4 = svbfdot(acc4, v1, v2);
+    }
+
+    // Accumulate accumulators
+    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc3);
+    acc2 = svadd_f32_x(svptrue_b32(), acc2, acc4);
+    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc2);
+
+    // Reduce the accumulated sum.
+    float result = svaddv_f32(svptrue_b32(), acc1);
+    return 1.0f - result;
+}
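
Note: a hedged sketch (illustrative, not this repo's actual dispatcher) of how the two template parameters could be derived from the runtime dimension and the hardware vector length, mirroring the kernel's loop structure:

#include <arm_sve.h>
#include <cstddef>
// assumes BF16_InnerProduct_SVE from the header above is visible

float BF16_IP_SVE_Dispatch(const void *v1, const void *v2, size_t dim) {
    const size_t chunk = svcnth();                 // 16-bit lanes per SVE register
    const bool partial = (dim % chunk) != 0;       // tail handled by predicate
    const unsigned char steps = (dim / chunk) % 4; // whole chunks left after the 4x-unrolled loop
    if (partial) {
        switch (steps) {
        case 0: return BF16_InnerProduct_SVE<true, 0>(v1, v2, dim);
        case 1: return BF16_InnerProduct_SVE<true, 1>(v1, v2, dim);
        case 2: return BF16_InnerProduct_SVE<true, 2>(v1, v2, dim);
        default: return BF16_InnerProduct_SVE<true, 3>(v1, v2, dim);
        }
    }
    switch (steps) {
    case 0: return BF16_InnerProduct_SVE<false, 0>(v1, v2, dim);
    case 1: return BF16_InnerProduct_SVE<false, 1>(v1, v2, dim);
    case 2: return BF16_InnerProduct_SVE<false, 2>(v1, v2, dim);
    default: return BF16_InnerProduct_SVE<false, 3>(v1, v2, dim);
    }
}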
