Skip to content

Commit 5bebc1f

Browse files
committed
Wire "bulk with offset" functions, c++ implementation, ARM optimization
1 parent 5dec80b commit 5bebc1f

File tree

10 files changed

+344
-64
lines changed

10 files changed

+344
-64
lines changed

libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public interface VectorSimilarityFunctions {
4747
*/
4848
MethodHandle dotProductHandle7uBulk();
4949

50+
MethodHandle dotProductHandle7uBulkWithOffsets();
51+
5052
/**
5153
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
5254
*

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
3333

3434
static final MethodHandle dot7u$mh;
3535
static final MethodHandle dot7uBulk$mh;
36+
static final MethodHandle dot7uBulkWithOffsets$mh;
3637
static final MethodHandle sqr7u$mh;
3738
static final MethodHandle cosf32$mh;
3839
static final MethodHandle dotf32$mh;
@@ -59,6 +60,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
5960
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
6061
LinkerHelperUtil.critical()
6162
);
63+
dot7uBulkWithOffsets$mh = downcallHandle(
64+
"dot7u_bulk_offsets_2",
65+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
66+
LinkerHelperUtil.critical()
67+
);
6268
sqr7u$mh = downcallHandle(
6369
"sqr7u_2",
6470
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -90,6 +96,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
9096
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
9197
LinkerHelperUtil.critical()
9298
);
99+
dot7uBulkWithOffsets$mh = downcallHandle(
100+
"dot7u_bulk_offsets",
101+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
102+
LinkerHelperUtil.critical()
103+
);
93104
sqr7u$mh = downcallHandle(
94105
"sqr7u",
95106
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -120,6 +131,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
120131
}
121132
dot7u$mh = null;
122133
dot7uBulk$mh = null;
134+
dot7uBulkWithOffsets$mh = null;
123135
sqr7u$mh = null;
124136
cosf32$mh = null;
125137
dotf32$mh = null;
@@ -161,6 +173,17 @@ static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int c
161173
dot7uBulk(a, b, length, count, result);
162174
}
163175

176+
static void dotProduct7uBulkWithOffsets(
177+
MemorySegment a,
178+
MemorySegment b,
179+
int length,
180+
MemorySegment offsets,
181+
int count,
182+
MemorySegment result
183+
) {
184+
dot7uBulkWithOffsets(a, b, length, offsets, count, result);
185+
}
186+
164187
/**
165188
* Computes the square distance of given unsigned int7 byte vectors.
166189
*
@@ -237,6 +260,21 @@ private static void dot7uBulk(MemorySegment a, MemorySegment b, int length, int
237260
}
238261
}
239262

263+
private static void dot7uBulkWithOffsets(
264+
MemorySegment a,
265+
MemorySegment b,
266+
int length,
267+
MemorySegment offsets,
268+
int count,
269+
MemorySegment result
270+
) {
271+
try {
272+
JdkVectorLibrary.dot7uBulkWithOffsets$mh.invokeExact(a, b, length, offsets, count, result);
273+
} catch (Throwable t) {
274+
throw new AssertionError(t);
275+
}
276+
}
277+
240278
private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
241279
try {
242280
return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length);
@@ -271,6 +309,7 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
271309

272310
static final MethodHandle DOT_HANDLE_7U;
273311
static final MethodHandle DOT_HANDLE_7U_BULK;
312+
static final MethodHandle DOT_HANDLE_7U_BULK_WITH_OFFSETS;
274313
static final MethodHandle SQR_HANDLE_7U;
275314
static final MethodHandle COS_HANDLE_FLOAT32;
276315
static final MethodHandle DOT_HANDLE_FLOAT32;
@@ -286,6 +325,17 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
286325
mt = MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class);
287326
DOT_HANDLE_7U_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulk", mt);
288327

328+
mt = MethodType.methodType(
329+
void.class,
330+
MemorySegment.class,
331+
MemorySegment.class,
332+
int.class,
333+
MemorySegment.class,
334+
int.class,
335+
MemorySegment.class
336+
);
337+
DOT_HANDLE_7U_BULK_WITH_OFFSETS = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulkWithOffsets", mt);
338+
289339
mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
290340
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt);
291341
DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt);
@@ -305,6 +355,11 @@ public MethodHandle dotProductHandle7uBulk() {
305355
return DOT_HANDLE_7U_BULK;
306356
}
307357

358+
@Override
359+
public MethodHandle dotProductHandle7uBulkWithOffsets() {
360+
return DOT_HANDLE_7U_BULK_WITH_OFFSETS;
361+
}
362+
308363
@Override
309364
public MethodHandle squareDistanceHandle7u() {
310365
return SQR_HANDLE_7U;

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,42 @@ public void testInt7uBulk() {
118118
}
119119
}
120120

121+
public void testInt7uBulkWithOffsets() {
122+
assumeTrue(notSupportedMsg(), supported());
123+
final int dims = size;
124+
final int numVecs = randomIntBetween(2, 101);
125+
var offsets = new int[numVecs];
126+
var values = new byte[numVecs][dims];
127+
var segment = arena.allocate((long) dims * numVecs);
128+
for (int i = 0; i < numVecs; i++) {
129+
offsets[i] = randomInt(numVecs - 1);
130+
randomBytesBetween(values[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
131+
MemorySegment.copy(MemorySegment.ofArray(values[i]), 0L, segment, (long) i * dims, dims);
132+
}
133+
int queryOrd = randomInt(numVecs - 1);
134+
float[] expectedScores = new float[numVecs];
135+
dotProductBulkWithOffsetsScalar(values[queryOrd], values, offsets, expectedScores);
136+
137+
var nativeQuerySeg = segment.asSlice((long) queryOrd * dims, dims);
138+
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
139+
var offsetsSeg = arena.allocate((long) numVecs * Integer.BYTES);
140+
dotProduct7uBulkWithOffsets(segment, nativeQuerySeg, dims, offsetsSeg, numVecs, bulkScoresSeg);
141+
assertScoresEquals(expectedScores, bulkScoresSeg);
142+
143+
if (supportsHeapSegments()) {
144+
float[] bulkScores = new float[numVecs];
145+
dotProduct7uBulkWithOffsets(
146+
segment,
147+
nativeQuerySeg,
148+
dims,
149+
MemorySegment.ofArray(offsets),
150+
numVecs,
151+
MemorySegment.ofArray(bulkScores)
152+
);
153+
assertArrayEquals(expectedScores, bulkScores, 0f);
154+
}
155+
}
156+
121157
public void testIllegalDims() {
122158
assumeTrue(notSupportedMsg(), supported());
123159
var segment = arena.allocate((long) size * 3);
@@ -203,6 +239,20 @@ void dotProduct7uBulk(MemorySegment a, MemorySegment b, int dims, int count, Mem
203239
}
204240
}
205241

242+
void dotProduct7uBulkWithOffsets(MemorySegment a, MemorySegment b, int dims, MemorySegment offsets, int count, MemorySegment result) {
243+
try {
244+
getVectorDistance().dotProductHandle7uBulkWithOffsets().invokeExact(a, b, dims, count, result);
245+
} catch (Throwable e) {
246+
if (e instanceof Error err) {
247+
throw err;
248+
} else if (e instanceof RuntimeException re) {
249+
throw re;
250+
} else {
251+
throw new RuntimeException(e);
252+
}
253+
}
254+
}
255+
206256
/** Computes the dot product of the given vectors a and b. */
207257
static int dotProductScalar(byte[] a, byte[] b) {
208258
int res = 0;
@@ -229,11 +279,16 @@ static void dotProductBulkScalar(byte[] query, byte[][] data, float[] scores) {
229279
}
230280
}
231281

282+
static void dotProductBulkWithOffsetsScalar(byte[] query, byte[][] data, int[] offsets, float[] scores) {
283+
for (int i = 0; i < data.length; i++) {
284+
scores[i] = dotProductScalar(query, data[offsets[i]]);
285+
}
286+
}
287+
232288
static void assertScoresEquals(float[] expectedScores, MemorySegment expectedScoresSeg) {
233289
assert expectedScores.length == (expectedScoresSeg.byteSize() / Float.BYTES);
234290
for (int i = 0; i < expectedScores.length; i++) {
235291
assertEquals(expectedScores[i], expectedScoresSeg.get(JAVA_FLOAT_UNALIGNED, i * Float.BYTES), 0f);
236292
}
237293
}
238-
239294
}

libs/simdvec/native/build.gradle

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ model {
4949
target("aarch64") {
5050
cCompiler.executable = "/usr/bin/gcc"
5151
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=armv8-a"]) }
52+
cppCompiler.executable = "/usr/bin/g++"
53+
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=armv8-a"]) }
5254
}
5355
target("amd64") {
5456
cCompiler.executable = "/usr/bin/gcc"
@@ -68,6 +70,7 @@ model {
6870
clang(Clang) {
6971
target("aarch64") {
7072
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=armv8-a"]) }
73+
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=armv8-a"]) }
7174
}
7275

7376
target("amd64") {

libs/simdvec/native/src/vec/c/aarch64/vec.c

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -95,31 +95,6 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, const int32_t dims) {
9595
return res;
9696
}
9797

98-
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
99-
int32_t res = 0;
100-
if (dims > DOT7U_STRIDE_BYTES_LEN) {
101-
const int limit = dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
102-
for (int32_t c = 0; c < count; c++) {
103-
int i = limit;
104-
res = dot7u_inner(a, b, i);
105-
for (; i < dims; i++) {
106-
res += a[i] * b[i];
107-
}
108-
results[c] = (float_t)res;
109-
a += dims;
110-
}
111-
} else {
112-
for (int32_t c = 0; c < count; c++) {
113-
res = 0;
114-
for (int32_t i = 0; i < dims; i++) {
115-
res += a[i] * b[i];
116-
}
117-
results[c] = (float_t)res;
118-
a += dims;
119-
}
120-
}
121-
}
122-
12398
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {
12499
int32x4_t acc1 = vdupq_n_s32(0);
125100
int32x4_t acc2 = vdupq_n_s32(0);
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
#include <stddef.h>
11+
#include <stdint.h>
12+
#include <arm_neon.h>
13+
#include <math.h>
14+
extern "C" {
15+
#include "vec.h"
16+
}
17+
18+
#ifndef DOT7U_STRIDE_BYTES_LEN
19+
#define DOT7U_STRIDE_BYTES_LEN 32 // Must be a power of 2
20+
#endif
21+
22+
#ifndef SQR7U_STRIDE_BYTES_LEN
23+
#define SQR7U_STRIDE_BYTES_LEN 16 // Must be a power of 2
24+
#endif
25+
26+
/*
 * Bulk int7 dot product of the query vector `b` against `count` vectors stored back to
 * back in `a` (each `dims` bytes long). Slot `c` of `results` receives the score of the
 * vector whose ordinal is `mapper(c, offsets)`.
 *
 * Vectors are processed four at a time so that each 16-byte block of the query is loaded
 * once and reused for all four candidates. Per block, vmull_s8 widens the byte products
 * to 16 bits and vpadalq_s16 pairwise-accumulates them into 32-bit lanes. Bytes beyond the
 * last whole 16-byte block are handled by a scalar loop, and the 0..3 vectors left over
 * after the groups of four are delegated to dot7u.
 *
 * Indices are kept signed so that a non-positive `count` or `dims` simply skips the
 * loops, and the vector start address is computed in ptrdiff_t so that ordinal * dims
 * cannot overflow 32-bit int arithmetic on large data sets.
 */
template <int(*mapper)(int, int32_t*)>
static inline void dot7u_inner_bulk(int8_t* a, int8_t* b, int dims, int32_t* offsets, int count, f32_t* results) {
    // Number of leading bytes covered by whole 16-byte NEON blocks.
    const int blk = dims & ~15;
    int c = 0;

    // Process 4 vectors at a time
    for (; c + 3 < count; c += 4) {
        const int8_t* a0 = a + static_cast<ptrdiff_t>(mapper(c, offsets)) * dims;
        const int8_t* a1 = a + static_cast<ptrdiff_t>(mapper(c + 1, offsets)) * dims;
        const int8_t* a2 = a + static_cast<ptrdiff_t>(mapper(c + 2, offsets)) * dims;
        const int8_t* a3 = a + static_cast<ptrdiff_t>(mapper(c + 3, offsets)) * dims;

        // Two accumulators (low / high half of each block) per candidate vector.
        int32x4_t acc0 = vdupq_n_s32(0);
        int32x4_t acc1 = vdupq_n_s32(0);
        int32x4_t acc2 = vdupq_n_s32(0);
        int32x4_t acc3 = vdupq_n_s32(0);
        int32x4_t acc4 = vdupq_n_s32(0);
        int32x4_t acc5 = vdupq_n_s32(0);
        int32x4_t acc6 = vdupq_n_s32(0);
        int32x4_t acc7 = vdupq_n_s32(0);

        for (int i = 0; i < blk; i += 16) {
            // The query block is loaded once and shared by all four candidates.
            int8x16_t vb = vld1q_s8(b + i);

            int8x16_t v0 = vld1q_s8(a0 + i);
            int16x8_t lo0 = vmull_s8(vget_low_s8(v0), vget_low_s8(vb));
            int16x8_t hi0 = vmull_s8(vget_high_s8(v0), vget_high_s8(vb));
            acc0 = vpadalq_s16(acc0, lo0);
            acc1 = vpadalq_s16(acc1, hi0);

            int8x16_t v1 = vld1q_s8(a1 + i);
            int16x8_t lo1 = vmull_s8(vget_low_s8(v1), vget_low_s8(vb));
            int16x8_t hi1 = vmull_s8(vget_high_s8(v1), vget_high_s8(vb));
            acc2 = vpadalq_s16(acc2, lo1);
            acc3 = vpadalq_s16(acc3, hi1);

            int8x16_t v2 = vld1q_s8(a2 + i);
            int16x8_t lo2 = vmull_s8(vget_low_s8(v2), vget_low_s8(vb));
            int16x8_t hi2 = vmull_s8(vget_high_s8(v2), vget_high_s8(vb));
            acc4 = vpadalq_s16(acc4, lo2);
            acc5 = vpadalq_s16(acc5, hi2);

            int8x16_t v3 = vld1q_s8(a3 + i);
            int16x8_t lo3 = vmull_s8(vget_low_s8(v3), vget_low_s8(vb));
            int16x8_t hi3 = vmull_s8(vget_high_s8(v3), vget_high_s8(vb));
            acc6 = vpadalq_s16(acc6, lo3);
            acc7 = vpadalq_s16(acc7, hi3);
        }

        // Horizontal reduction: merge the two halves, then sum the four lanes.
        int32_t acc_scalar0 = vaddvq_s32(vaddq_s32(acc0, acc1));
        int32_t acc_scalar1 = vaddvq_s32(vaddq_s32(acc2, acc3));
        int32_t acc_scalar2 = vaddvq_s32(vaddq_s32(acc4, acc5));
        int32_t acc_scalar3 = vaddvq_s32(vaddq_s32(acc6, acc7));

        // Scalar tail: the dims % 16 trailing bytes (no iterations when blk == dims).
        for (int t = blk; t < dims; t++) {
            const int8_t bb = b[t];
            acc_scalar0 += a0[t] * bb;
            acc_scalar1 += a1[t] * bb;
            acc_scalar2 += a2[t] * bb;
            acc_scalar3 += a3[t] * bb;
        }

        results[c + 0] = (f32_t)acc_scalar0;
        results[c + 1] = (f32_t)acc_scalar1;
        results[c + 2] = (f32_t)acc_scalar2;
        results[c + 3] = (f32_t)acc_scalar3;
    }

    // Tail-handling: remaining 0..3 vectors
    for (; c < count; c++) {
        int8_t* a0 = a + static_cast<ptrdiff_t>(mapper(c, offsets)) * dims;
        results[c] = (f32_t)dot7u(a0, b, dims);
    }
}
105+
106+
// Mapper for contiguous scoring: slot i scores vector i. The offsets table is not
// consulted (callers may pass a null pointer).
static inline int identity(int i, int32_t* offsets) {
    (void)offsets;
    return i;
}
109+
110+
// Mapper for indirect scoring: slot i scores the vector whose ordinal is stored at
// position i of the offsets table.
static inline int index(int i, int32_t* offsets) {
    const int32_t ordinal = offsets[i];
    return ordinal;
}
113+
114+
extern "C"
115+
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, int dims, int count, f32_t* results) {
116+
dot7u_inner_bulk<identity>(a, b, dims, NULL, count, results);
117+
}
118+
119+
extern "C"
120+
EXPORT void dot7u_bulk_offsets(int8_t* a, int8_t* b, int dims, int32_t* offsets, int count, f32_t* results) {
121+
dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
122+
}
123+

0 commit comments

Comments
 (0)