Skip to content

Commit 0dffc46

Browse files
authored
[SIMD][ARM] Optimized native bulk dot product scoring for Int7 (#138552)
In #138204 @benwtrent implemented bulk scoring for int7 centroid scoring; recent Lucene versions also introduced the possibility to provide specialized code for bulk scoring by overriding RandomVectorScorer#bulkScore. In this PR we generalized Ben's native bulk int7 scoring implementation to work with the Lucene case too. We provided a simple implementation for x86, and we optimized the ARM implementation to issue multiple memory access instructions at the same time. This unrolled/optimized code for ARM has a minor benefit in the sequential access pattern too.
1 parent bd82d91 commit 0dffc46

File tree

16 files changed

+758
-53
lines changed

16 files changed

+758
-53
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmark.java

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.apache.lucene.store.IOContext;
1515
import org.apache.lucene.store.IndexInput;
1616
import org.apache.lucene.store.MMapDirectory;
17+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
1718
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
1819
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
1920
import org.elasticsearch.common.logging.LogConfigurator;
@@ -48,6 +49,7 @@
4849
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.createRandomInt7VectorData;
4950
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.getScorerFactoryOrDie;
5051
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScoreSupplier;
52+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScorer;
5153
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.readNodeCorrectionConstant;
5254
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;
5355
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.vectorValues;
@@ -80,10 +82,10 @@ public class VectorScorerInt7uBulkBenchmark {
8082

8183
// 128k is typically enough to not fit in L1 (core) cache for most processors;
8284
// 1.5M is typically enough to not fit in L2 (core) cache;
83-
// 40M is typically enough to not fit in L3 cache
84-
@Param({ "128000", "1500000", "30000000" })
85+
// 130M is enough to not fit in L3 cache
86+
@Param({ "128", "1500", "130000" })
8587
public int numVectors;
86-
public int numVectorsToScore = 20_000;
88+
public int numVectorsToScore;
8789

8890
Path path;
8991
Directory dir;
@@ -100,8 +102,12 @@ public class VectorScorerInt7uBulkBenchmark {
100102
UpdateableRandomVectorScorer luceneDotScorer;
101103
UpdateableRandomVectorScorer nativeDotScorer;
102104

105+
RandomVectorScorer luceneDotScorerQuery;
106+
RandomVectorScorer nativeDotScorerQuery;
107+
103108
@Setup(Level.Trial)
104109
public void setup() throws IOException {
110+
numVectorsToScore = Math.min(numVectors, 20_000);
105111
factory = getScorerFactoryOrDie();
106112

107113
var random = ThreadLocalRandom.current();
@@ -127,6 +133,17 @@ public void setup() throws IOException {
127133
.orElseThrow()
128134
.scorer();
129135
nativeDotScorer.setScoringOrdinal(targetOrd);
136+
137+
if (supportsHeapSegments()) {
138+
// setup for getInt7SQVectorScorer / query vector scoring
139+
float[] queryVec = new float[dims];
140+
for (int i = 0; i < dims; i++) {
141+
queryVec[i] = random.nextFloat();
142+
}
143+
luceneDotScorerQuery = luceneScorer(dotProductValues, VectorSimilarityFunction.DOT_PRODUCT, queryVec);
144+
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, dotProductValues, queryVec)
145+
.orElseThrow();
146+
}
130147
}
131148

132149
@TearDown
@@ -151,6 +168,14 @@ public float[] dotProductLuceneMultipleRandom() throws IOException {
151168
return scores;
152169
}
153170

171+
@Benchmark
172+
public float[] dotProductLuceneQueryMultipleRandom() throws IOException {
173+
for (int v = 0; v < numVectorsToScore; v++) {
174+
scores[v] = luceneDotScorerQuery.score(ordinals[v]);
175+
}
176+
return scores;
177+
}
178+
154179
@Benchmark
155180
public float[] dotProductNativeMultipleSequential() throws IOException {
156181
for (int v = 0; v < numVectorsToScore; v++) {
@@ -167,6 +192,14 @@ public float[] dotProductNativeMultipleRandom() throws IOException {
167192
return scores;
168193
}
169194

195+
@Benchmark
196+
public float[] dotProductNativeQueryMultipleRandom() throws IOException {
197+
for (int v = 0; v < numVectorsToScore; v++) {
198+
scores[v] = nativeDotScorerQuery.score(ordinals[v]);
199+
}
200+
return scores;
201+
}
202+
170203
@Benchmark
171204
public float[] dotProductNativeMultipleSequentialBulk() throws IOException {
172205
nativeDotScorer.bulkScore(ids, scores, ordinals.length);
@@ -179,6 +212,12 @@ public float[] dotProductNativeMultipleRandomBulk() throws IOException {
179212
return scores;
180213
}
181214

215+
@Benchmark
216+
public float[] dotProductNativeQueryMultipleRandomBulk() throws IOException {
217+
nativeDotScorerQuery.bulkScore(ordinals, scores, ordinals.length);
218+
return scores;
219+
}
220+
182221
@Benchmark
183222
public float[] dotProductScalarMultipleSequential() throws IOException {
184223
var queryVector = dotProductValues.vectorValue(targetOrd);

benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmarkTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import java.util.Arrays;
2020

21+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;
22+
2123
public class VectorScorerInt7uBulkBenchmarkTests extends ESTestCase {
2224

2325
final float delta = 1e-3f;
@@ -61,6 +63,11 @@ public void testDotProductRandom() throws Exception {
6163
assertArrayEquals(expected, bench.dotProductLuceneMultipleRandom(), delta);
6264
assertArrayEquals(expected, bench.dotProductNativeMultipleRandom(), delta);
6365
assertArrayEquals(expected, bench.dotProductNativeMultipleRandomBulk(), delta);
66+
if (supportsHeapSegments()) {
67+
assertArrayEquals(expected, bench.dotProductLuceneQueryMultipleRandom(), delta);
68+
assertArrayEquals(expected, bench.dotProductNativeQueryMultipleRandom(), delta);
69+
assertArrayEquals(expected, bench.dotProductNativeQueryMultipleRandomBulk(), delta);
70+
}
6471
} finally {
6572
bench.teardown();
6673
}

docs/changelog/138552.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138552
2+
summary: "[SIMD][ARM] Optimized native bulk dot product scoring for Int7"
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

libs/native/libraries/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ configurations {
1919
}
2020

2121
var zstdVersion = "1.5.5"
22-
var vecVersion = "1.0.17"
22+
var vecVersion = "1.0.18"
2323

2424
repositories {
2525
exclusiveContent {

libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,32 @@ public interface VectorSimilarityFunctions {
4747
*/
4848
MethodHandle dotProductHandle7uBulk();
4949

50+
/**
51+
* Produces a method handle which computes the dot product of several byte (unsigned
52+
* int7) vectors. This bulk operation can be used to compute the dot product between a
53+
* single query vector and a subset of vectors from a dataset (array of vectors). Each
54+
* vector to include in the operation is identified by an offset inside the dataset.
55+
*
56+
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
57+
*
58+
* <p> The type of the method handle will have {@code void} as return type. The type of
59+
* its arguments will be:
60+
* <ol>
61+
* <li>a {@code MemorySegment} containing the vector data bytes for several vectors;
62+
* in other words, a contiguous array of vectors</li>
63+
* <li>a {@code MemorySegment} containing the vector data bytes for a single ("query") vector</li>
64+
* <li>an {@code int}, representing the dimensions of each vector</li>
65+
* <li>an {@code int}, representing the width (in bytes) of each vector. Or, in other words,
66+
* the distance in bytes between two vectors inside the first param's {@code MemorySegment}</li>
67+
* <li>a {@code MemorySegment} containing the indices of the vectors inside the first param's array
68+
* on which we'll compute the dot product</li>
69+
* <li>an {@code int}, representing the number of vectors for which we'll compute the dot product
70+
* (which is equal to the size - in number of elements - of the 5th and 7th {@code MemorySegment}s)</li>
71+
* <li>a {@code MemorySegment}, into which the computed dot product float values will be stored</li>
72+
* </ol>
73+
*/
74+
MethodHandle dotProductHandle7uBulkWithOffsets();
75+
5076
/**
5177
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
5278
*

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
3333

3434
static final MethodHandle dot7u$mh;
3535
static final MethodHandle dot7uBulk$mh;
36+
static final MethodHandle dot7uBulkWithOffsets$mh;
3637
static final MethodHandle sqr7u$mh;
3738
static final MethodHandle cosf32$mh;
3839
static final MethodHandle dotf32$mh;
@@ -59,6 +60,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
5960
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
6061
LinkerHelperUtil.critical()
6162
);
63+
dot7uBulkWithOffsets$mh = downcallHandle(
64+
"vec_dot7u_bulk_offsets_2",
65+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
66+
LinkerHelperUtil.critical()
67+
);
6268
sqr7u$mh = downcallHandle(
6369
"vec_sqr7u_2",
6470
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -90,6 +96,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
9096
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
9197
LinkerHelperUtil.critical()
9298
);
99+
dot7uBulkWithOffsets$mh = downcallHandle(
100+
"vec_dot7u_bulk_offsets",
101+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
102+
LinkerHelperUtil.critical()
103+
);
93104
sqr7u$mh = downcallHandle(
94105
"vec_sqr7u",
95106
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -120,6 +131,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
120131
}
121132
dot7u$mh = null;
122133
dot7uBulk$mh = null;
134+
dot7uBulkWithOffsets$mh = null;
123135
sqr7u$mh = null;
124136
cosf32$mh = null;
125137
dotf32$mh = null;
@@ -161,6 +173,18 @@ static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int c
161173
dot7uBulk(a, b, length, count, result);
162174
}
163175

176+
static void dotProduct7uBulkWithOffsets(
177+
MemorySegment a,
178+
MemorySegment b,
179+
int length,
180+
int pitch,
181+
MemorySegment offsets,
182+
int count,
183+
MemorySegment result
184+
) {
185+
dot7uBulkWithOffsets(a, b, length, pitch, offsets, count, result);
186+
}
187+
164188
/**
165189
* Computes the square distance of given unsigned int7 byte vectors.
166190
*
@@ -237,6 +261,22 @@ private static void dot7uBulk(MemorySegment a, MemorySegment b, int length, int
237261
}
238262
}
239263

264+
private static void dot7uBulkWithOffsets(
265+
MemorySegment a,
266+
MemorySegment b,
267+
int length,
268+
int pitch,
269+
MemorySegment offsets,
270+
int count,
271+
MemorySegment result
272+
) {
273+
try {
274+
JdkVectorLibrary.dot7uBulkWithOffsets$mh.invokeExact(a, b, length, pitch, offsets, count, result);
275+
} catch (Throwable t) {
276+
throw new AssertionError(t);
277+
}
278+
}
279+
240280
private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
241281
try {
242282
return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length);
@@ -271,6 +311,7 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
271311

272312
static final MethodHandle DOT_HANDLE_7U;
273313
static final MethodHandle DOT_HANDLE_7U_BULK;
314+
static final MethodHandle DOT_HANDLE_7U_BULK_WITH_OFFSETS;
274315
static final MethodHandle SQR_HANDLE_7U;
275316
static final MethodHandle COS_HANDLE_FLOAT32;
276317
static final MethodHandle DOT_HANDLE_FLOAT32;
@@ -286,6 +327,21 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
286327
mt = MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class);
287328
DOT_HANDLE_7U_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulk", mt);
288329

330+
DOT_HANDLE_7U_BULK_WITH_OFFSETS = lookup.findStatic(
331+
JdkVectorSimilarityFunctions.class,
332+
"dotProduct7uBulkWithOffsets",
333+
MethodType.methodType(
334+
void.class,
335+
MemorySegment.class,
336+
MemorySegment.class,
337+
int.class,
338+
int.class,
339+
MemorySegment.class,
340+
int.class,
341+
MemorySegment.class
342+
)
343+
);
344+
289345
mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
290346
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt);
291347
DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt);
@@ -305,6 +361,11 @@ public MethodHandle dotProductHandle7uBulk() {
305361
return DOT_HANDLE_7U_BULK;
306362
}
307363

364+
@Override
365+
public MethodHandle dotProductHandle7uBulkWithOffsets() {
366+
return DOT_HANDLE_7U_BULK_WITH_OFFSETS;
367+
}
368+
308369
@Override
309370
public MethodHandle squareDistanceHandle7u() {
310371
return SQR_HANDLE_7U;

0 commit comments

Comments
 (0)