Commit e644a3d

lerman25 and dor-forer authored
Add SVE/SVE2 support for uint8 and int8 data type [MOD-9080] (#619)
* Add arm support * Changed the arm cpu info * Add ip test * Add to tests * Added tests and bm * fix tests * Add github benchmarks * Check 1 * only arm * change ami * Try ireland * Try different image * try image * back to old image * larger image * Add option to change env * back to default region * Created new image * Try to add the x86 to check * Try different machine * added include * Try without opti on arm * Change to c6g * added matrix region * change to west * try the i8 * Try oregon * Change subnet id * Now subnet * Change subnet * add subnet * Try group id * Change to vpc id * change subnet * Change ami * Try without subnet * add security group again * Change the subnets * Change to ids * Change sg * psubnet * Try different * different * to a file * print * p * leave empty * empty * Try different account * Run 2 arm machines * Move both to us-west-2 * Try workflow * Change name * Changes * Change the secrets * Add supported arch * Add defaults * Support all * Change the jq * Change machine to t4g * Change the name * Change the machine * fix the stop * only benchmark * add the secrets * region secret * benchmark region * Change timeout * Added support for arch name in benchmarks * change the json * changed to v9.0 * Change the check * add v9 * Check alt version of armv9 * added check * add arc_arch * changed to CONCAT_WITH_UNDERSCORE_ARCH * change the check * Add full check * fix the instruct * Added the cmake * fix the support * put it back to cmake * back * change the condition * No armpl for now * clang format * remove the opt * Changed to one machine * Added BENCHMARK_ARCH * fix endif * Remove secrets call * pr changes * Changes * change to compile * add sve * add #endif * add armpl * add to cmake * remove armpl * add install * Add ARCH=$(uname -m) * change the path to armpl * support check for armv7 * change the armpl * Change or OR * add neon supported for spaces * add sve * add support * align * format * change error * change * Removed the ifdef * Add comments * clang * Change names * format * Try fp32 neon simd * add l2 * add cmake * add SVE * fix sve l2 * PR changes * Change to 1 * fix the l2 * fix format * add description for chunk == 1 * Change functions * Add include * Change the cast * add residual * formatting * Move the constexpr * remove template armpl * Back to armpl * back to armpl_neon * include * armpl * add choose * fix the residual div * raise the residuals values * back to char * Remove prefetch * Revert implementation chooser * Remove armpl * Revert remove error * Remove comment * Remove empty line * format * Add support macos * add sudo * Add absolute path * find all libs * Change folder * Now set for real * Remove armpl from pull * change the templates * change chunk size to 1 * Back to 4 * Removed the for * Change to 2 sums * SVE L2 * Changed * Add get opt func * Change the var name * format * Pr fixes * PR * SVE IP, SVE2 IP & L2 * UINT8 support, remove int8_ip_sve * format * pr * pr fix * bm_spaces * PR * added conversion * small dim for intel only * Test smallDimChooser only for intel * align offset * align const expression * align cpu features function * format * change to svadd_f32_x where possible * change to _x where possible * move low dim check to intel only * format * fix IP * Optimize, convert on final step * format * chunking * change to inline * format * guy's comments * fix unit_test * format * reinterpret comment * using dot * fix uint8 * SVE2 -> SVE * format * fix comments * format :( * illegal

---------

Co-authored-by: Dor Forer <dor.forer@redis.com>
1 parent 1e08ea4 commit e644a3d

File tree

13 files changed: +831 additions, -11 deletions


src/VecSim/spaces/IP/IP_SVE_INT8.h

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/*
+ * Copyright Redis Ltd. 2021 - present
+ * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
+ * the Server Side Public License v1 (SSPLv1).
+ */
+#pragma once
+#include "VecSim/spaces/space_includes.h"
+#include <arm_sve.h>
+
+inline void InnerProductStep(const int8_t *&pVect1, const int8_t *&pVect2, size_t &offset,
+                             svint32_t &sum, const size_t chunk) {
+    svbool_t pg = svptrue_b8();
+
+    // Load int8 vectors
+    svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset);
+    svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset);
+
+    sum = svdot_s32(sum, v1_i8, v2_i8);
+
+    offset += chunk; // Move to the next set of int8 elements
+}
+
+template <bool partial_chunk, unsigned char additional_steps>
+float INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    const int8_t *pVect1 = reinterpret_cast<const int8_t *>(pVect1v);
+    const int8_t *pVect2 = reinterpret_cast<const int8_t *>(pVect2v);
+
+    size_t offset = 0;
+    const size_t vl = svcntb();
+    const size_t chunk_size = 4 * vl;
+
+    // Each InnerProductStep adds at most 2^8 * 2^8 = 2^16 per lane.
+    // Therefore, a single accumulator can perform 2^15 steps before overflowing.
+    // That scenario happens only if the vector dimension is larger than 16 * 4 * 2^15 = 2^21
+    // (16 int8 in one SVE register) * (4 accumulators) * (2^15 steps).
+    // We can safely assume that the dimension is smaller than that,
+    // so using int32_t accumulators is safe.
+
+    svint32_t sum0 = svdup_s32(0);
+    svint32_t sum1 = svdup_s32(0);
+    svint32_t sum2 = svdup_s32(0);
+    svint32_t sum3 = svdup_s32(0);
+
+    size_t num_chunks = dimension / chunk_size;
+
+    for (size_t i = 0; i < num_chunks; ++i) {
+        InnerProductStep(pVect1, pVect2, offset, sum0, vl);
+        InnerProductStep(pVect1, pVect2, offset, sum1, vl);
+        InnerProductStep(pVect1, pVect2, offset, sum2, vl);
+        InnerProductStep(pVect1, pVect2, offset, sum3, vl);
+    }
+
+    // Process remaining complete SVE vectors that didn't fit into the main loop.
+    // These are full-vector operations (0-3 additional full vectors).
+    if constexpr (additional_steps > 0) {
+        if constexpr (additional_steps >= 1) {
+            InnerProductStep(pVect1, pVect2, offset, sum0, vl);
+        }
+        if constexpr (additional_steps >= 2) {
+            InnerProductStep(pVect1, pVect2, offset, sum1, vl);
+        }
+        if constexpr (additional_steps >= 3) {
+            InnerProductStep(pVect1, pVect2, offset, sum2, vl);
+        }
+    }
+
+    if constexpr (partial_chunk) {
+        svbool_t pg = svwhilelt_b8_u64(offset, dimension);
+
+        svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset); // Load int8 vectors
+        svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset); // Load int8 vectors
+
+        sum3 = svdot_s32(sum3, v1_i8, v2_i8);
+
+        pVect1 += vl;
+        pVect2 += vl;
+    }
+
+    sum0 = svadd_s32_x(svptrue_b32(), sum0, sum1);
+    sum2 = svadd_s32_x(svptrue_b32(), sum2, sum3);
+
+    // Perform vector addition in parallel and horizontal sum
+    int32_t sum_all = svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sum0, sum2));
+
+    return sum_all;
+}
+
+template <bool partial_chunk, unsigned char additional_steps>
+float INT8_InnerProductSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    return 1.0f -
+           INT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
+}
+
+template <bool partial_chunk, unsigned char additional_steps>
+float INT8_CosineSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    float ip = INT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
+    float norm_v1 =
+        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect1v) + dimension);
+    float norm_v2 =
+        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect2v) + dimension);
+    return 1.0f - ip / (norm_v1 * norm_v2);
+}
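
Editor's note: the partial_chunk and additional_steps template parameters are resolved by a chooser that is not part of this hunk. The sketch below only illustrates how such a chooser could map a dimension onto the instantiations above, assuming chunk_size = 4 * svcntb() as in INT8_InnerProductImp; the names Choose_INT8_IP_implementation_SVE_sketch and the simplified dist_func alias are hypothetical, not the library's API.

// Hedged sketch only; assumes the IP_SVE_INT8.h templates above are in scope.
#include <cstddef>
#include <arm_sve.h>

using dist_func = float (*)(const void *, const void *, size_t);

dist_func Choose_INT8_IP_implementation_SVE_sketch(size_t dim) {
    const size_t vl = svcntb();        // int8 lanes per SVE vector
    const size_t chunk_size = 4 * vl;  // four accumulators per main-loop iteration
    const size_t residual = dim % chunk_size;
    const unsigned char additional_steps = residual / vl; // 0-3 full vectors left over
    const bool partial = (residual % vl) != 0;            // tail narrower than one vector

    // Dispatch to the matching template instantiation.
    if (partial) {
        switch (additional_steps) {
        case 0: return INT8_InnerProductSIMD_SVE<true, 0>;
        case 1: return INT8_InnerProductSIMD_SVE<true, 1>;
        case 2: return INT8_InnerProductSIMD_SVE<true, 2>;
        default: return INT8_InnerProductSIMD_SVE<true, 3>;
        }
    }
    switch (additional_steps) {
    case 0: return INT8_InnerProductSIMD_SVE<false, 0>;
    case 1: return INT8_InnerProductSIMD_SVE<false, 1>;
    case 2: return INT8_InnerProductSIMD_SVE<false, 2>;
    default: return INT8_InnerProductSIMD_SVE<false, 3>;
    }
}
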
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/*
+ * Copyright Redis Ltd. 2021 - present
+ * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
+ * the Server Side Public License v1 (SSPLv1).
+ */
+#pragma once
+#include "VecSim/spaces/space_includes.h"
+#include <arm_sve.h>
+
+inline void InnerProductStep(const uint8_t *&pVect1, const uint8_t *&pVect2, size_t &offset,
+                             svuint32_t &sum, const size_t chunk) {
+    svbool_t pg = svptrue_b8();
+
+    // Load uint8 vectors
+    svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset);
+    svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset);
+
+    sum = svdot_u32(sum, v1_ui8, v2_ui8);
+
+    offset += chunk; // Move to the next set of uint8 elements
+}
+
+template <bool partial_chunk, unsigned char additional_steps>
+float UINT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    const uint8_t *pVect1 = reinterpret_cast<const uint8_t *>(pVect1v);
+    const uint8_t *pVect2 = reinterpret_cast<const uint8_t *>(pVect2v);
+
+    size_t offset = 0;
+    const size_t vl = svcntb();
+    const size_t chunk_size = 4 * vl;
+
+    // Each InnerProductStep adds at most 2^8 * 2^8 = 2^16 per lane.
+    // Therefore, a single accumulator can perform 2^16 steps before overflowing.
+    // That scenario happens only if the vector dimension is larger than 16 * 4 * 2^16 = 2^22
+    // (16 uint8 in one SVE register) * (4 accumulators) * (2^16 steps).
+    // We can safely assume that the dimension is smaller than that,
+    // so using uint32_t accumulators is safe.
+
+    svuint32_t sum0 = svdup_u32(0);
+    svuint32_t sum1 = svdup_u32(0);
+    svuint32_t sum2 = svdup_u32(0);
+    svuint32_t sum3 = svdup_u32(0);
+
+    size_t num_chunks = dimension / chunk_size;
+
+    for (size_t i = 0; i < num_chunks; ++i) {
+        InnerProductStep(pVect1, pVect2, offset, sum0, vl);
+        InnerProductStep(pVect1, pVect2, offset, sum1, vl);
+        InnerProductStep(pVect1, pVect2, offset, sum2, vl);
+        InnerProductStep(pVect1, pVect2, offset, sum3, vl);
+    }
+
+    // Process remaining complete SVE vectors that didn't fit into the main loop.
+    // These are full-vector operations (0-3 additional full vectors).
+    if constexpr (additional_steps > 0) {
+        if constexpr (additional_steps >= 1) {
+            InnerProductStep(pVect1, pVect2, offset, sum0, vl);
+        }
+        if constexpr (additional_steps >= 2) {
+            InnerProductStep(pVect1, pVect2, offset, sum1, vl);
+        }
+        if constexpr (additional_steps >= 3) {
+            InnerProductStep(pVect1, pVect2, offset, sum2, vl);
+        }
+    }
+
+    if constexpr (partial_chunk) {
+        svbool_t pg = svwhilelt_b8_u64(offset, dimension);
+
+        svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset); // Load uint8 vectors
+        svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset); // Load uint8 vectors
+
+        sum3 = svdot_u32(sum3, v1_ui8, v2_ui8);
+
+        pVect1 += vl;
+        pVect2 += vl;
+    }
+
+    sum0 = svadd_u32_x(svptrue_b32(), sum0, sum1);
+    sum2 = svadd_u32_x(svptrue_b32(), sum2, sum3);
+
+    // Perform vector addition in parallel and horizontal sum
+    int32_t sum_all = svaddv_u32(svptrue_b32(), svadd_u32_x(svptrue_b32(), sum0, sum2));
+
+    return sum_all;
+}
+
+template <bool partial_chunk, unsigned char additional_steps>
+float UINT8_InnerProductSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    return 1.0f -
+           UINT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
+}
+
+template <bool partial_chunk, unsigned char additional_steps>
+float UINT8_CosineSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    float ip = UINT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
+    float norm_v1 =
+        *reinterpret_cast<const float *>(static_cast<const uint8_t *>(pVect1v) + dimension);
+    float norm_v2 =
+        *reinterpret_cast<const float *>(static_cast<const uint8_t *>(pVect2v) + dimension);
+    return 1.0f - ip / (norm_v1 * norm_v2);
+}
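
Editor's note: both cosine variants read a precomputed norm stored as a float immediately after the dimension quantized bytes (pVect + dimension). Below is a minimal sketch of that buffer layout, assuming a hypothetical prepare_uint8_vector helper; the real storage and preprocessing code lives elsewhere in VecSim.

// Hedged sketch: illustrates the [dim bytes | float norm] layout the cosine
// functions above expect; the helper name and packing are assumptions.
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<uint8_t> prepare_uint8_vector(const uint8_t *data, size_t dim) {
    std::vector<uint8_t> blob(dim + sizeof(float));
    std::memcpy(blob.data(), data, dim);

    float norm = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        norm += static_cast<float>(data[i]) * static_cast<float>(data[i]);
    }
    norm = std::sqrt(norm);

    // Store the norm right after the uint8 payload, where
    // UINT8_CosineSIMD_SVE reads it via pVect + dimension.
    std::memcpy(blob.data() + dim, &norm, sizeof(float));
    return blob;
}
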

src/VecSim/spaces/IP_space.cpp

Lines changed: 65 additions & 8 deletions
@@ -276,12 +276,27 @@ dist_func_t<float> IP_INT8_GetDistFunc(size_t dim, unsigned char *alignment, con
     }
 
     dist_func_t<float> ret_dist_func = INT8_InnerProduct;
+
+    auto features = getCpuOptimizationFeatures(arch_opt);
+
+#ifdef CPU_FEATURES_ARCH_AARCH64
+#ifdef OPT_SVE2
+    if (features.sve2) {
+        return Choose_INT8_IP_implementation_SVE2(dim);
+    }
+#endif
+#ifdef OPT_SVE
+    if (features.sve) {
+        return Choose_INT8_IP_implementation_SVE(dim);
+    }
+#endif
+#endif
+#ifdef CPU_FEATURES_ARCH_X86_64
     // Optimizations assume at least 32 int8. If we have less, we use the naive implementation.
     if (dim < 32) {
         return ret_dist_func;
     }
-#ifdef CPU_FEATURES_ARCH_X86_64
-    auto features = getCpuOptimizationFeatures(arch_opt);
+
 #ifdef OPT_AVX512_F_BW_VL_VNNI
     if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
         if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
@@ -301,12 +316,26 @@ dist_func_t<float> Cosine_INT8_GetDistFunc(size_t dim, unsigned char *alignment,
     }
 
     dist_func_t<float> ret_dist_func = INT8_Cosine;
+
+    auto features = getCpuOptimizationFeatures(arch_opt);
+
+#ifdef CPU_FEATURES_ARCH_AARCH64
+#ifdef OPT_SVE2
+    if (features.sve2) {
+        return Choose_INT8_Cosine_implementation_SVE2(dim);
+    }
+#endif
+#ifdef OPT_SVE
+    if (features.sve) {
+        return Choose_INT8_Cosine_implementation_SVE(dim);
+    }
+#endif
+#endif
+#ifdef CPU_FEATURES_ARCH_X86_64
     // Optimizations assume at least 32 int8. If we have less, we use the naive implementation.
     if (dim < 32) {
         return ret_dist_func;
     }
-#ifdef CPU_FEATURES_ARCH_X86_64
-    auto features = getCpuOptimizationFeatures(arch_opt);
 #ifdef OPT_AVX512_F_BW_VL_VNNI
     if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
         // For int8 vectors with cosine distance, the extra float for the norm shifts alignment to
@@ -329,12 +358,26 @@ dist_func_t<float> IP_UINT8_GetDistFunc(size_t dim, unsigned char *alignment,
     }
 
     dist_func_t<float> ret_dist_func = UINT8_InnerProduct;
+
+    auto features = getCpuOptimizationFeatures(arch_opt);
+
+#ifdef CPU_FEATURES_ARCH_AARCH64
+#ifdef OPT_SVE2
+    if (features.sve2) {
+        return Choose_UINT8_IP_implementation_SVE2(dim);
+    }
+#endif
+#ifdef OPT_SVE
+    if (features.sve) {
+        return Choose_UINT8_IP_implementation_SVE(dim);
+    }
+#endif
+#endif
+#ifdef CPU_FEATURES_ARCH_X86_64
     // Optimizations assume at least 32 uint8. If we have less, we use the naive implementation.
     if (dim < 32) {
         return ret_dist_func;
     }
-#ifdef CPU_FEATURES_ARCH_X86_64
-    auto features = getCpuOptimizationFeatures(arch_opt);
 #ifdef OPT_AVX512_F_BW_VL_VNNI
     if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
         if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
@@ -354,12 +397,26 @@ dist_func_t<float> Cosine_UINT8_GetDistFunc(size_t dim, unsigned char *alignment
     }
 
     dist_func_t<float> ret_dist_func = UINT8_Cosine;
+
+    auto features = getCpuOptimizationFeatures(arch_opt);
+
+#ifdef CPU_FEATURES_ARCH_AARCH64
+#ifdef OPT_SVE2
+    if (features.sve2) {
+        return Choose_UINT8_Cosine_implementation_SVE2(dim);
+    }
+#endif
+#ifdef OPT_SVE
+    if (features.sve) {
+        return Choose_UINT8_Cosine_implementation_SVE(dim);
+    }
+#endif
+#endif
+#ifdef CPU_FEATURES_ARCH_X86_64
     // Optimizations assume at least 32 uint8. If we have less, we use the naive implementation.
     if (dim < 32) {
         return ret_dist_func;
     }
-#ifdef CPU_FEATURES_ARCH_X86_64
-    auto features = getCpuOptimizationFeatures(arch_opt);
 #ifdef OPT_AVX512_F_BW_VL_VNNI
     if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
         // For uint8 vectors with cosine distance, the extra float for the norm shifts alignment to
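
Editor's note: each getter in IP_space.cpp follows the same pattern, shown schematically below. Compile-time #ifdef guards keep the SVE/SVE2 branches out of builds that cannot emit those instructions, while runtime cpu_features flags pick the fastest kernel the machine supports; the ARM branches sit before the dim < 32 early-return because SVE predication handles short tails natively, so the minimum-dimension guard only applies to the x86 kernels. In the sketch, select_uint8_ip, CpuFeatureFlags, and UINT8_InnerProduct_Naive are placeholders, not the library's API.

// Hedged sketch of the dispatch pattern above; forward declarations stand in
// for the real VecSim symbols.
#include <cstddef>

using dist_func = float (*)(const void *, const void *, size_t);

struct CpuFeatureFlags { // stand-in for the cpu_features result
    bool sve = false;
    bool sve2 = false;
};

float UINT8_InnerProduct_Naive(const void *, const void *, size_t); // scalar fallback
dist_func Choose_UINT8_IP_implementation_SVE(size_t dim);
dist_func Choose_UINT8_IP_implementation_SVE2(size_t dim);

dist_func select_uint8_ip(size_t dim, const CpuFeatureFlags &features) {
    dist_func ret = UINT8_InnerProduct_Naive; // always-available fallback
#ifdef CPU_FEATURES_ARCH_AARCH64
#ifdef OPT_SVE2
    if (features.sve2)
        return Choose_UINT8_IP_implementation_SVE2(dim); // prefer SVE2 when present
#endif
#ifdef OPT_SVE
    if (features.sve)
        return Choose_UINT8_IP_implementation_SVE(dim); // otherwise plain SVE
#endif
#endif
#ifdef CPU_FEATURES_ARCH_X86_64
    // x86 kernels assume at least 32 elements, so the early-return guards only
    // this path; the AVX-512 VNNI and other x86 selections follow in the real code.
    if (dim < 32)
        return ret;
#endif
    return ret;
}
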
