Skip to content

Commit 639bbf0

Browse files
committed
Another fix; add benchmarks to cover all paths
1 parent de51eea commit 639bbf0

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmark.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.apache.lucene.store.IOContext;
1515
import org.apache.lucene.store.IndexInput;
1616
import org.apache.lucene.store.MMapDirectory;
17+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
1718
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
1819
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
1920
import org.elasticsearch.common.logging.LogConfigurator;
@@ -48,6 +49,7 @@
4849
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.createRandomInt7VectorData;
4950
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.getScorerFactoryOrDie;
5051
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScoreSupplier;
52+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScorer;
5153
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.readNodeCorrectionConstant;
5254
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;
5355
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.vectorValues;
@@ -100,6 +102,9 @@ public class VectorScorerInt7uBulkBenchmark {
100102
UpdateableRandomVectorScorer luceneDotScorer;
101103
UpdateableRandomVectorScorer nativeDotScorer;
102104

105+
RandomVectorScorer luceneDotScorerQuery;
106+
RandomVectorScorer nativeDotScorerQuery;
107+
103108
@Setup(Level.Trial)
104109
public void setup() throws IOException {
105110
factory = getScorerFactoryOrDie();
@@ -127,6 +132,17 @@ public void setup() throws IOException {
127132
.orElseThrow()
128133
.scorer();
129134
nativeDotScorer.setScoringOrdinal(targetOrd);
135+
136+
if (supportsHeapSegments()) {
137+
// setup for getInt7SQVectorScorer / query vector scoring
138+
float[] queryVec = new float[dims];
139+
for (int i = 0; i < dims; i++) {
140+
queryVec[i] = random.nextFloat();
141+
}
142+
luceneDotScorerQuery = luceneScorer(dotProductValues, VectorSimilarityFunction.DOT_PRODUCT, queryVec);
143+
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, dotProductValues, queryVec)
144+
.orElseThrow();
145+
}
130146
}
131147

132148
@TearDown
@@ -151,6 +167,14 @@ public float[] dotProductLuceneMultipleRandom() throws IOException {
151167
return scores;
152168
}
153169

170+
@Benchmark
171+
public float[] dotProductLuceneQueryMultipleRandom() throws IOException {
172+
for (int v = 0; v < numVectorsToScore; v++) {
173+
scores[v] = luceneDotScorerQuery.score(ordinals[v]);
174+
}
175+
return scores;
176+
}
177+
154178
@Benchmark
155179
public float[] dotProductNativeMultipleSequential() throws IOException {
156180
for (int v = 0; v < numVectorsToScore; v++) {
@@ -167,6 +191,14 @@ public float[] dotProductNativeMultipleRandom() throws IOException {
167191
return scores;
168192
}
169193

194+
@Benchmark
195+
public float[] dotProductNativeQueryMultipleRandom() throws IOException {
196+
for (int v = 0; v < numVectorsToScore; v++) {
197+
scores[v] = nativeDotScorerQuery.score(ordinals[v]);
198+
}
199+
return scores;
200+
}
201+
170202
@Benchmark
171203
public float[] dotProductNativeMultipleSequentialBulk() throws IOException {
172204
nativeDotScorer.bulkScore(ids, scores, ordinals.length);
@@ -179,6 +211,12 @@ public float[] dotProductNativeMultipleRandomBulk() throws IOException {
179211
return scores;
180212
}
181213

214+
@Benchmark
215+
public float[] dotProductNativeQueryMultipleRandomBulk() throws IOException {
216+
nativeDotScorerQuery.bulkScore(ordinals, scores, ordinals.length);
217+
return scores;
218+
}
219+
182220
@Benchmark
183221
public float[] dotProductScalarMultipleSequential() throws IOException {
184222
var queryVector = dotProductValues.vectorValue(targetOrd);

benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmarkTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import java.util.Arrays;
2020

21+
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;
22+
2123
public class VectorScorerInt7uBulkBenchmarkTests extends ESTestCase {
2224

2325
final float delta = 1e-3f;
@@ -61,6 +63,11 @@ public void testDotProductRandom() throws Exception {
6163
assertArrayEquals(expected, bench.dotProductLuceneMultipleRandom(), delta);
6264
assertArrayEquals(expected, bench.dotProductNativeMultipleRandom(), delta);
6365
assertArrayEquals(expected, bench.dotProductNativeMultipleRandomBulk(), delta);
66+
if (supportsHeapSegments()) {
67+
assertArrayEquals(expected, bench.dotProductLuceneQueryMultipleRandom(), delta);
68+
assertArrayEquals(expected, bench.dotProductNativeQueryMultipleRandom(), delta);
69+
assertArrayEquals(expected, bench.dotProductNativeQueryMultipleRandomBulk(), delta);
70+
}
6471
} finally {
6572
bench.teardown();
6673
}

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
6161
LinkerHelperUtil.critical()
6262
);
6363
dot7uBulkWithOffsets$mh = downcallHandle(
64-
"dot7u_bulk_offsets_2",
64+
"vec_dot7u_bulk_offsets_2",
6565
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
6666
LinkerHelperUtil.critical()
6767
);
@@ -97,7 +97,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
9797
LinkerHelperUtil.critical()
9898
);
9999
dot7uBulkWithOffsets$mh = downcallHandle(
100-
"dot7u_bulk_offsets",
100+
"vec_dot7u_bulk_offsets",
101101
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
102102
LinkerHelperUtil.critical()
103103
);

0 commit comments

Comments
 (0)