Skip to content

Commit c9f4c8e

Browse files
ldemattencordon
authored andcommitted
[SIMD] Move native/vec code to C++ (elastic#138525)
Native code in the simdvec library is currently a mix of C and C++ code. We found that C++ templates are helpful to reduce source code duplication while retaining a great (sometime even greater) code inlining and expansion (e.g. loop unrolling) that we use to maximize performance. This PR moves all existing C code to C++; in general, it's just a matter of renaming + disabling name mangling, plus renaming of the exported functions. This last operation is needed as they name clash with the "extended" math functions (e.g. cosf32 which is the cosine function over float32_t types). Now all the exported symbols have a vec_ prefix.
1 parent 722b209 commit c9f4c8e

File tree

8 files changed

+72
-76
lines changed

8 files changed

+72
-76
lines changed

libs/native/libraries/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ configurations {
1919
}
2020

2121
var zstdVersion = "1.5.5"
22-
var vecVersion = "1.0.16"
22+
var vecVersion = "1.0.17"
2323

2424
repositories {
2525
exclusiveContent {

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,63 +50,63 @@ public final class JdkVectorLibrary implements VectorLibrary {
5050
if (caps > 0) {
5151
if (caps == 2) {
5252
dot7u$mh = downcallHandle(
53-
"dot7u_2",
53+
"vec_dot7u_2",
5454
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
5555
LinkerHelperUtil.critical()
5656
);
5757
dot7uBulk$mh = downcallHandle(
58-
"dot7u_bulk_2",
58+
"vec_dot7u_bulk_2",
5959
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
6060
LinkerHelperUtil.critical()
6161
);
6262
sqr7u$mh = downcallHandle(
63-
"sqr7u_2",
63+
"vec_sqr7u_2",
6464
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
6565
LinkerHelperUtil.critical()
6666
);
6767
cosf32$mh = downcallHandle(
68-
"cosf32_2",
68+
"vec_cosf32_2",
6969
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
7070
LinkerHelperUtil.critical()
7171
);
7272
dotf32$mh = downcallHandle(
73-
"dotf32_2",
73+
"vec_dotf32_2",
7474
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
7575
LinkerHelperUtil.critical()
7676
);
7777
sqrf32$mh = downcallHandle(
78-
"sqrf32_2",
78+
"vec_sqrf32_2",
7979
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
8080
LinkerHelperUtil.critical()
8181
);
8282
} else {
8383
dot7u$mh = downcallHandle(
84-
"dot7u",
84+
"vec_dot7u",
8585
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
8686
LinkerHelperUtil.critical()
8787
);
8888
dot7uBulk$mh = downcallHandle(
89-
"dot7u_bulk",
89+
"vec_dot7u_bulk",
9090
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
9191
LinkerHelperUtil.critical()
9292
);
9393
sqr7u$mh = downcallHandle(
94-
"sqr7u",
94+
"vec_sqr7u",
9595
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
9696
LinkerHelperUtil.critical()
9797
);
9898
cosf32$mh = downcallHandle(
99-
"cosf32",
99+
"vec_cosf32",
100100
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
101101
LinkerHelperUtil.critical()
102102
);
103103
dotf32$mh = downcallHandle(
104-
"dotf32",
104+
"vec_dotf32",
105105
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
106106
LinkerHelperUtil.critical()
107107
);
108108
sqrf32$mh = downcallHandle(
109-
"sqrf32",
109+
"vec_sqrf32",
110110
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
111111
LinkerHelperUtil.critical()
112112
);

libs/simdvec/native/build.gradle

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ var os = org.gradle.internal.os.OperatingSystem.current()
2929
// objdump --disassemble-symbols=_dot7u build/libs/vec/shared/aarch64/libvec.dylib
3030
// Note: symbol decoration may differ on Linux, i.e. the leading underscore is not present
3131
//
32-
// gcc -shared -fpic -o libvec.so -I src/vec/headers/ src/vec/c/vec.c -O3
32+
// g++ -shared -fpic -o libvec.so -I src/vec/headers/ src/vec/c/vec.c -O3
3333

3434
group = 'org.elasticsearch'
3535

@@ -47,12 +47,10 @@ model {
4747
toolChains {
4848
gcc(Gcc) {
4949
target("aarch64") {
50-
cCompiler.executable = "/usr/bin/gcc"
51-
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=armv8-a"]) }
50+
cppCompiler.executable = "/usr/bin/g++"
51+
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=armv8-a"]) }
5252
}
5353
target("amd64") {
54-
cCompiler.executable = "/usr/bin/gcc"
55-
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=core-avx2", "-Wno-incompatible-pointer-types"]) }
5654
cppCompiler.executable = "/usr/bin/g++"
5755
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) }
5856
}
@@ -61,17 +59,16 @@ model {
6159
eachPlatform { toolchain ->
6260
def platform = toolchain.getPlatform()
6361
if (platform.name == "x64") {
64-
cCompiler.withArguments { args -> args.addAll(["/O2", "/LD", "-march=core-avx2"]) }
62+
cppCompiler.withArguments { args -> args.addAll(["/O2", "/LD", "-march=core-avx2"]) }
6563
}
6664
}
6765
}
6866
clang(Clang) {
6967
target("aarch64") {
70-
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=armv8-a"]) }
68+
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=armv8-a"]) }
7169
}
7270

7371
target("amd64") {
74-
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=core-avx2"]) }
7572
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) }
7673
}
7774
}

libs/simdvec/native/publish_vec_binaries.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
2020
exit 1;
2121
fi
2222

23-
VERSION="1.0.16"
23+
VERSION="1.0.17"
2424
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
2525
TEMP=$(mktemp -d)
2626

libs/simdvec/native/src/vec/c/aarch64/vec.c renamed to libs/simdvec/native/src/vec/c/aarch64/vec_1.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
* License v3.0 only", or the "Server Side Public License, v 1".
88
*/
99

10+
// This file contains implementations for basic vector processing functionalities,
11+
// including support for "1st tier" vector capabilities; in the case of ARM,
12+
// this first tier include functions for processors supporting at least the NEON
13+
// instruction set.
14+
1015
#include <stddef.h>
1116
#include <arm_neon.h>
1217
#include <math.h>
@@ -48,7 +53,7 @@ EXPORT int vec_caps() {
4853
#endif
4954
}
5055

51-
static inline int32_t dot7u_inner(int8_t* a, int8_t* b, const int32_t dims) {
56+
static inline int32_t dot7u_inner(const int8_t* a, const int8_t* b, const int32_t dims) {
5257
// We have contention in the instruction pipeline on the accumulation
5358
// registers if we use too few.
5459
int32x4_t acc1 = vdupq_n_s32(0);
@@ -82,7 +87,7 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, const int32_t dims) {
8287
return vaddvq_s32(vaddq_s32(acc5, acc6));
8388
}
8489

85-
EXPORT int32_t dot7u(int8_t* a, int8_t* b, const int32_t dims) {
90+
EXPORT int32_t vec_dot7u(int8_t* a, int8_t* b, const int32_t dims) {
8691
int32_t res = 0;
8792
int i = 0;
8893
if (dims > DOT7U_STRIDE_BYTES_LEN) {
@@ -95,7 +100,7 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, const int32_t dims) {
95100
return res;
96101
}
97102

98-
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
103+
EXPORT void vec_dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
99104
int32_t res = 0;
100105
if (dims > DOT7U_STRIDE_BYTES_LEN) {
101106
const int limit = dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
@@ -105,7 +110,7 @@ EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int
105110
for (; i < dims; i++) {
106111
res += a[i] * b[i];
107112
}
108-
results[c] = (float_t)res;
113+
results[c] = (f32_t)res;
109114
a += dims;
110115
}
111116
} else {
@@ -114,7 +119,7 @@ EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int
114119
for (int32_t i = 0; i < dims; i++) {
115120
res += a[i] * b[i];
116121
}
117-
results[c] = (float_t)res;
122+
results[c] = (f32_t)res;
118123
a += dims;
119124
}
120125
}
@@ -145,7 +150,7 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {
145150
return vaddvq_s32(vaddq_s32(acc5, acc6));
146151
}
147152

148-
EXPORT int32_t sqr7u(int8_t* a, int8_t* b, const int32_t dims) {
153+
EXPORT int32_t vec_sqr7u(int8_t* a, int8_t* b, const int32_t dims) {
149154
int32_t res = 0;
150155
int i = 0;
151156
if (dims > SQR7U_STRIDE_BYTES_LEN) {
@@ -164,7 +169,7 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, const int32_t dims) {
164169
// const f32_t *a pointer to the first float vector
165170
// const f32_t *b pointer to the second float vector
166171
// const int32_t elementCount the number of floating point elements
167-
EXPORT f32_t dotf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
172+
EXPORT f32_t vec_dotf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
168173
float32x4_t sum0 = vdupq_n_f32(0.0f);
169174
float32x4_t sum1 = vdupq_n_f32(0.0f);
170175
float32x4_t sum2 = vdupq_n_f32(0.0f);
@@ -205,7 +210,7 @@ EXPORT f32_t dotf32(const f32_t *a, const f32_t *b, const int32_t elementCount)
205210
// const f32_t *a pointer to the first float vector
206211
// const f32_t *b pointer to the second float vector
207212
// const int32_t elementCount the number of floating point elements
208-
EXPORT f32_t cosf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
213+
EXPORT f32_t vec_cosf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
209214
float32x4_t sum0 = vdupq_n_f32(0.0f);
210215
float32x4_t sum1 = vdupq_n_f32(0.0f);
211216
float32x4_t sum2 = vdupq_n_f32(0.0f);
@@ -277,7 +282,7 @@ EXPORT f32_t cosf32(const f32_t *a, const f32_t *b, const int32_t elementCount)
277282
return dot / denom;
278283
}
279284

280-
EXPORT f32_t sqrf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
285+
EXPORT f32_t vec_sqrf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
281286
float32x4_t sum0 = vdupq_n_f32(0.0f);
282287
float32x4_t sum1 = vdupq_n_f32(0.0f);
283288
float32x4_t sum2 = vdupq_n_f32(0.0f);

libs/simdvec/native/src/vec/c/amd64/vec.c renamed to libs/simdvec/native/src/vec/c/amd64/vec_1.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
* License v3.0 only", or the "Server Side Public License, v 1".
88
*/
99

10+
// This file contains implementations for basic vector processing functionalities,
11+
// including support for "1st tier" vector capabilities; in the case of x64,
12+
// this first tier include functions for processors supporting at least AVX2.
13+
1014
#include <stddef.h>
1115
#include <stdint.h>
1216
#include <math.h>
@@ -116,7 +120,7 @@ EXPORT int vec_caps() {
116120
return 0;
117121
}
118122

119-
static inline int32_t dot7u_inner(int8_t* a, int8_t* b, const int32_t dims) {
123+
static inline int32_t dot7u_inner(const int8_t* a, const int8_t* b, const int32_t dims) {
120124
const __m256i ones = _mm256_set1_epi16(1);
121125

122126
// Init accumulator(s) with 0
@@ -125,8 +129,8 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, const int32_t dims) {
125129
#pragma GCC unroll 4
126130
for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) {
127131
// Load packed 8-bit integers
128-
__m256i va1 = _mm256_loadu_si256(a + i);
129-
__m256i vb1 = _mm256_loadu_si256(b + i);
132+
__m256i va1 = _mm256_loadu_si256((const __m256i_u *)(a + i));
133+
__m256i vb1 = _mm256_loadu_si256((const __m256i_u *)(b + i));
130134

131135
// Perform multiplication and create 16-bit values
132136
// Vertically multiply each unsigned 8-bit integer from va with the corresponding
@@ -140,7 +144,7 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, const int32_t dims) {
140144
return hsum_i32_8(acc1);
141145
}
142146

143-
EXPORT int32_t dot7u(int8_t* a, int8_t* b, const int32_t dims) {
147+
EXPORT int32_t vec_dot7u(int8_t* a, int8_t* b, const int32_t dims) {
144148
int32_t res = 0;
145149
int i = 0;
146150
if (dims > STRIDE_BYTES_LEN) {
@@ -153,7 +157,7 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, const int32_t dims) {
153157
return res;
154158
}
155159

156-
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
160+
EXPORT void vec_dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
157161
int32_t res = 0;
158162
if (dims > STRIDE_BYTES_LEN) {
159163
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
@@ -163,7 +167,7 @@ EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int
163167
for (; i < dims; i++) {
164168
res += a[i] * b[i];
165169
}
166-
results[c] = (float_t)res;
170+
results[c] = (f32_t)res;
167171
a += dims;
168172
}
169173
} else {
@@ -172,7 +176,7 @@ EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int
172176
for (int32_t i = 0; i < dims; i++) {
173177
res += a[i] * b[i];
174178
}
175-
results[c] = (float_t)res;
179+
results[c] = (f32_t)res;
176180
a += dims;
177181
}
178182
}
@@ -187,8 +191,8 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {
187191
#pragma GCC unroll 4
188192
for(int i = 0; i < dims; i += STRIDE_BYTES_LEN) {
189193
// Load packed 8-bit integers
190-
__m256i va1 = _mm256_loadu_si256(a + i);
191-
__m256i vb1 = _mm256_loadu_si256(b + i);
194+
__m256i va1 = _mm256_loadu_si256((const __m256i_u *)(a + i));
195+
__m256i vb1 = _mm256_loadu_si256((const __m256i_u *)(b + i));
192196

193197
const __m256i dist1 = _mm256_sub_epi8(va1, vb1);
194198
const __m256i abs_dist1 = _mm256_sign_epi8(dist1, dist1);
@@ -200,7 +204,7 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {
200204
return hsum_i32_8(acc1);
201205
}
202206

203-
EXPORT int32_t sqr7u(int8_t* a, int8_t* b, const int32_t dims) {
207+
EXPORT int32_t vec_sqr7u(int8_t* a, int8_t* b, const int32_t dims) {
204208
int32_t res = 0;
205209
int i = 0;
206210
if (dims > STRIDE_BYTES_LEN) {
@@ -236,7 +240,7 @@ static inline f32_t hsum_f32_8(const __m256 v) {
236240
// const f32_t *a pointer to the first float vector
237241
// const f32_t *b pointer to the second float vector
238242
// const int32_t elementCount the number of floating point elements
239-
EXPORT f32_t cosf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
243+
EXPORT f32_t vec_cosf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
240244
__m256 dot0 = _mm256_setzero_ps();
241245
__m256 dot1 = _mm256_setzero_ps();
242246
__m256 dot2 = _mm256_setzero_ps();
@@ -309,7 +313,7 @@ EXPORT f32_t cosf32(const f32_t *a, const f32_t *b, const int32_t elementCount)
309313
// const f32_t *a pointer to the first float vector
310314
// const f32_t *b pointer to the second float vector
311315
// const int32_t elementCount the number of floating point elements
312-
EXPORT f32_t dotf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
316+
EXPORT f32_t vec_dotf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
313317
__m256 acc0 = _mm256_setzero_ps();
314318
__m256 acc1 = _mm256_setzero_ps();
315319
__m256 acc2 = _mm256_setzero_ps();
@@ -339,7 +343,7 @@ EXPORT f32_t dotf32(const f32_t *a, const f32_t *b, const int32_t elementCount)
339343
// const f32_t *a pointer to the first float vector
340344
// const f32_t *b pointer to the second float vector
341345
// const int32_t elementCount the number of floating point elements
342-
EXPORT f32_t sqrf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
346+
EXPORT f32_t vec_sqrf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
343347
__m256 sum0 = _mm256_setzero_ps();
344348
__m256 sum1 = _mm256_setzero_ps();
345349
__m256 sum2 = _mm256_setzero_ps();

0 commit comments

Comments
 (0)