Skip to content

Commit 3a2f8f6

Browse files
authored
Add square distance query variants to the vector distance benchmark (#119219)
This commit adds square distance query variants to the vector distance benchmark.
1 parent 2c736f4 commit 3a2f8f6

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ public class VectorScorerBenchmark {
8383

8484
RandomVectorScorer luceneDotScorerQuery;
8585
RandomVectorScorer nativeDotScorerQuery;
86+
RandomVectorScorer luceneSqrScorerQuery;
87+
RandomVectorScorer nativeSqrScorerQuery;
8688

8789
@Setup
8890
public void setup() throws IOException {
@@ -130,6 +132,8 @@ public void setup() throws IOException {
130132
}
131133
luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec);
132134
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get();
135+
luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec);
136+
nativeSqrScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get();
133137

134138
// sanity
135139
var f1 = dotProductLucene();
@@ -157,6 +161,12 @@ public void setup() throws IOException {
157161
if (q1 != q2) {
158162
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
159163
}
164+
165+
var sqr1 = squareDistanceLuceneQuery();
166+
var sqr2 = squareDistanceNativeQuery();
167+
if (sqr1 != sqr2) {
168+
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
169+
}
160170
}
161171

162172
@TearDown
@@ -217,6 +227,16 @@ public float squareDistanceScalar() {
217227
return 1 / (1f + adjustedDistance);
218228
}
219229

230+
@Benchmark
231+
public float squareDistanceLuceneQuery() throws IOException {
232+
return luceneSqrScorerQuery.score(1);
233+
}
234+
235+
@Benchmark
236+
public float squareDistanceNativeQuery() throws IOException {
237+
return nativeSqrScorerQuery.score(1);
238+
}
239+
220240
QuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
221241
var sq = new ScalarQuantizer(0.1f, 0.9f, (byte) 7);
222242
var slice = in.slice("values", 0, in.length());

0 commit comments

Comments
 (0)