Skip to content

Commit 0b3b48a

Browse files
authored
Reinstate and test the native int7u bulk dot product (elastic#138317)
This commit reinstates, fixes, and tests the native int7u bulk dot product elastic#138239 The issue with the original change is that size_t is 64 bit, while we pass a 32 bit int from java. This is also arguably an issue with the other native definitions, but doesn't cause an issue because of position in the declaration (last). We should fix these declarations, but as a separate PR. closes elastic#138302
1 parent 816d501 commit 0b3b48a

File tree

11 files changed

+167
-14
lines changed

11 files changed

+167
-14
lines changed

libs/native/libraries/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ configurations {
1919
}
2020

2121
var zstdVersion = "1.5.5"
22-
var vecVersion = "1.0.13"
22+
var vecVersion = "1.0.15"
2323

2424
repositories {
2525
exclusiveContent {

libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,23 @@ public interface VectorSimilarityFunctions {
3030
*/
3131
MethodHandle dotProductHandle7u();
3232

33+
/**
34+
* Produces a method handle which computes the dot product of several byte (unsigned
35+
* int7) vectors. This bulk operation can be used to compute the dot product between a
36+
* single query vector and a number of other vectors.
37+
*
38+
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
39+
*
40+
* <p> The type of the method handle will have {@code void} as return type. The type of
41+
* its first and second arguments will be {@code MemorySegment}, the former contains the
42+
* vector data bytes for several vectors, while the latter just a single vector. The
43+
* type of the third argument is an int, representing the dimensions of each vector. The
44+
* type of the fourth argument is an int, representing the number of vectors in the
45+
* first argument. The type of the final argument is a MemorySegment, into which the
46+
* computed dot product float values will be stored.
47+
*/
48+
MethodHandle dotProductHandle7uBulk();
49+
3350
/**
3451
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
3552
*

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
3232
static final Logger logger = LogManager.getLogger(JdkVectorLibrary.class);
3333

3434
static final MethodHandle dot7u$mh;
35+
static final MethodHandle dot7uBulk$mh;
3536
static final MethodHandle sqr7u$mh;
3637
static final MethodHandle cosf32$mh;
3738
static final MethodHandle dotf32$mh;
@@ -53,6 +54,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
5354
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
5455
LinkerHelperUtil.critical()
5556
);
57+
dot7uBulk$mh = downcallHandle(
58+
"dot7u_bulk_2",
59+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
60+
LinkerHelperUtil.critical()
61+
);
5662
sqr7u$mh = downcallHandle(
5763
"sqr7u_2",
5864
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -79,6 +85,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
7985
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
8086
LinkerHelperUtil.critical()
8187
);
88+
dot7uBulk$mh = downcallHandle(
89+
"dot7u_bulk",
90+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
91+
LinkerHelperUtil.critical()
92+
);
8293
sqr7u$mh = downcallHandle(
8394
"sqr7u",
8495
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -108,6 +119,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
108119
enable them in your OS/Hypervisor/VM/container""");
109120
}
110121
dot7u$mh = null;
122+
dot7uBulk$mh = null;
111123
sqr7u$mh = null;
112124
cosf32$mh = null;
113125
dotf32$mh = null;
@@ -142,6 +154,13 @@ static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
142154
return dot7u(a, b, length);
143155
}
144156

157+
static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
158+
Objects.checkFromIndexSize(0, length * count, (int) a.byteSize());
159+
Objects.checkFromIndexSize(0, length, (int) b.byteSize());
160+
Objects.checkFromIndexSize(0, count * Float.BYTES, (int) result.byteSize());
161+
dot7uBulk(a, b, length, count, result);
162+
}
163+
145164
/**
146165
* Computes the square distance of given unsigned int7 byte vectors.
147166
*
@@ -210,6 +229,14 @@ private static int dot7u(MemorySegment a, MemorySegment b, int length) {
210229
}
211230
}
212231

232+
private static void dot7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
233+
try {
234+
JdkVectorLibrary.dot7uBulk$mh.invokeExact(a, b, length, count, result);
235+
} catch (Throwable t) {
236+
throw new AssertionError(t);
237+
}
238+
}
239+
213240
private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
214241
try {
215242
return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length);
@@ -243,6 +270,7 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
243270
}
244271

245272
static final MethodHandle DOT_HANDLE_7U;
273+
static final MethodHandle DOT_HANDLE_7U_BULK;
246274
static final MethodHandle SQR_HANDLE_7U;
247275
static final MethodHandle COS_HANDLE_FLOAT32;
248276
static final MethodHandle DOT_HANDLE_FLOAT32;
@@ -255,6 +283,9 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
255283
DOT_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7u", mt);
256284
SQR_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance7u", mt);
257285

286+
mt = MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class);
287+
DOT_HANDLE_7U_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulk", mt);
288+
258289
mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
259290
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt);
260291
DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt);
@@ -269,6 +300,11 @@ public MethodHandle dotProductHandle7u() {
269300
return DOT_HANDLE_7U;
270301
}
271302

303+
@Override
304+
public MethodHandle dotProductHandle7uBulk() {
305+
return DOT_HANDLE_7U_BULK;
306+
}
307+
272308
@Override
273309
public MethodHandle squareDistanceHandle7u() {
274310
return SQR_HANDLE_7U;

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.lang.foreign.MemorySegment;
1919

20+
import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED;
2021
import static org.hamcrest.Matchers.containsString;
2122

2223
public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests {
@@ -71,6 +72,11 @@ public void testInt7BinaryVectors() {
7172
assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims));
7273
assertEquals(expected, dotProduct7u(nativeSeg1, heapSeg2, dims));
7374
assertEquals(expected, dotProduct7u(heapSeg1, nativeSeg2, dims));
75+
76+
// trivial bulk with a single vector
77+
float[] bulkScore = new float[1];
78+
dotProduct7uBulk(nativeSeg1, nativeSeg2, dims, 1, MemorySegment.ofArray(bulkScore));
79+
assertEquals(expected, bulkScore[0], 0f);
7480
}
7581

7682
// square distance
@@ -86,6 +92,32 @@ public void testInt7BinaryVectors() {
8692
}
8793
}
8894

95+
public void testInt7uBulk() {
96+
assumeTrue(notSupportedMsg(), supported());
97+
final int dims = size;
98+
final int numVecs = randomIntBetween(2, 101);
99+
var values = new byte[numVecs][dims];
100+
var segment = arena.allocate((long) dims * numVecs);
101+
for (int i = 0; i < numVecs; i++) {
102+
randomBytesBetween(values[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
103+
MemorySegment.copy(MemorySegment.ofArray(values[i]), 0L, segment, (long) i * dims, dims);
104+
}
105+
int queryOrd = randomInt(numVecs - 1);
106+
float[] expectedScores = new float[numVecs];
107+
dotProductBulkScalar(values[queryOrd], values, expectedScores);
108+
109+
var nativeQuerySeg = segment.asSlice((long) queryOrd * dims, dims);
110+
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
111+
dotProduct7uBulk(segment, nativeQuerySeg, dims, numVecs, bulkScoresSeg);
112+
assertScoresEquals(expectedScores, bulkScoresSeg);
113+
114+
if (supportsHeapSegments()) {
115+
float[] bulkScores = new float[numVecs];
116+
dotProduct7uBulk(segment, nativeQuerySeg, dims, numVecs, MemorySegment.ofArray(bulkScores));
117+
assertArrayEquals(expectedScores, bulkScores, 0f);
118+
}
119+
}
120+
89121
public void testIllegalDims() {
90122
assumeTrue(notSupportedMsg(), supported());
91123
var segment = arena.allocate((long) size * 3);
@@ -109,6 +141,26 @@ public void testIllegalDims() {
109141
assertThat(e6.getMessage(), containsString("out of bounds for length"));
110142
}
111143

144+
public void testBulkIllegalDims() {
145+
assumeTrue(notSupportedMsg(), supported());
146+
var segA = arena.allocate((long) size * 3);
147+
var segB = arena.allocate((long) size * 3);
148+
var segS = arena.allocate((long) size * Float.BYTES);
149+
150+
var e1 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, size, 4, segS));
151+
assertThat(e1.getMessage(), containsString("out of bounds for length"));
152+
153+
var e2 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, size, -1, segS));
154+
assertThat(e2.getMessage(), containsString("out of bounds for length"));
155+
156+
var e3 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, -1, 3, segS));
157+
assertThat(e3.getMessage(), containsString("out of bounds for length"));
158+
159+
var tooSmall = arena.allocate((long) 3 * Float.BYTES - 1);
160+
var e4 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, size, 3, tooSmall));
161+
assertThat(e4.getMessage(), containsString("out of bounds for length"));
162+
}
163+
112164
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
113165
try {
114166
return (int) getVectorDistance().dotProductHandle7u().invokeExact(a, b, length);
@@ -137,6 +189,20 @@ int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
137189
}
138190
}
139191

192+
void dotProduct7uBulk(MemorySegment a, MemorySegment b, int dims, int count, MemorySegment result) {
193+
try {
194+
getVectorDistance().dotProductHandle7uBulk().invokeExact(a, b, dims, count, result);
195+
} catch (Throwable e) {
196+
if (e instanceof Error err) {
197+
throw err;
198+
} else if (e instanceof RuntimeException re) {
199+
throw re;
200+
} else {
201+
throw new RuntimeException(e);
202+
}
203+
}
204+
}
205+
140206
/** Computes the dot product of the given vectors a and b. */
141207
static int dotProductScalar(byte[] a, byte[] b) {
142208
int res = 0;
@@ -156,4 +222,18 @@ static int squareDistanceScalar(byte[] a, byte[] b) {
156222
}
157223
return squareSum;
158224
}
225+
226+
static void dotProductBulkScalar(byte[] query, byte[][] data, float[] scores) {
227+
for (int i = 0; i < data.length; i++) {
228+
scores[i] = dotProductScalar(query, data[i]);
229+
}
230+
}
231+
232+
static void assertScoresEquals(float[] expectedScores, MemorySegment expectedScoresSeg) {
233+
assert expectedScores.length == (expectedScoresSeg.byteSize() / Float.BYTES);
234+
for (int i = 0; i < expectedScores.length; i++) {
235+
assertEquals(expectedScores[i], expectedScoresSeg.get(JAVA_FLOAT_UNALIGNED, i * Float.BYTES), 0f);
236+
}
237+
}
238+
159239
}

libs/simdvec/native/publish_vec_binaries.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
2020
exit 1;
2121
fi
2222

23-
VERSION="1.0.14"
23+
VERSION="1.0.15"
2424
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
2525
TEMP=$(mktemp -d)
2626

libs/simdvec/native/src/vec/c/aarch64/vec.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) {
9595
return res;
9696
}
9797

98-
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results) {
98+
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
9999
int32_t res = 0;
100100
if (dims > DOT7U_STRIDE_BYTES_LEN) {
101-
int limit = dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
101+
const int limit = dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
102102
for (size_t c = 0; c < count; c++) {
103103
int i = limit;
104104
res = dot7u_inner(a, b, i);

libs/simdvec/native/src/vec/c/amd64/vec.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) {
153153
return res;
154154
}
155155

156-
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results) {
156+
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
157157
int32_t res = 0;
158158
if (dims > STRIDE_BYTES_LEN) {
159-
int limit = dims & ~(STRIDE_BYTES_LEN - 1);
159+
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
160160
for (size_t c = 0; c < count; c++) {
161161
int i = limit;
162162
res = dot7u_inner(a, b, i);

libs/simdvec/native/src/vec/c/amd64/vec_2.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ inline __m512i fma8(__m512i acc, const int8_t* p1, const int8_t* p2) {
4747
return _mm512_add_epi32(_mm512_madd_epi16(ones, dot), acc);
4848
}
4949

50-
static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) {
50+
static inline int32_t dot7u_inner_avx512(int8_t* a, const int8_t* b, size_t dims) {
5151
constexpr int stride8 = 8 * STRIDE_BYTES_LEN;
5252
constexpr int stride4 = 4 * STRIDE_BYTES_LEN;
5353
const int8_t* p1 = a;
@@ -115,10 +115,10 @@ EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) {
115115
}
116116

117117
extern "C"
118-
EXPORT void dot7u_bulk_2(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results) {
118+
EXPORT void dot7u_bulk_2(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
119119
int32_t res = 0;
120120
if (dims > STRIDE_BYTES_LEN) {
121-
int limit = dims & ~(STRIDE_BYTES_LEN - 1);
121+
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
122122
for (size_t c = 0; c < count; c++) {
123123
int i = limit;
124124
res = dot7u_inner_avx512(a, b, i);

libs/simdvec/native/src/vec/headers/vec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ EXPORT int vec_caps();
1919

2020
EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims);
2121

22-
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results);
22+
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results);
2323

2424
EXPORT int32_t sqr7u(int8_t *a, int8_t *b, size_t length);
2525

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public class Similarities {
2323

2424
static final MethodHandle DOT_PRODUCT_7U = DISTANCE_FUNCS.dotProductHandle7u();
2525
static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u();
26+
static final MethodHandle DOT_PRODUCT_7U_BULK = DISTANCE_FUNCS.dotProductHandle7uBulk();
2627

2728
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
2829
try {
@@ -51,4 +52,18 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
5152
}
5253
}
5354
}
55+
56+
static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment scores) {
57+
try {
58+
DOT_PRODUCT_7U_BULK.invokeExact(a, b, length, count, scores);
59+
} catch (Throwable e) {
60+
if (e instanceof Error err) {
61+
throw err;
62+
} else if (e instanceof RuntimeException re) {
63+
throw re;
64+
} else {
65+
throw new RuntimeException(e);
66+
}
67+
}
68+
}
5469
}

0 commit comments

Comments
 (0)