Skip to content

Commit 72e1a84

Browse files
committed
Refactor VectorScorerBenchmark to Int7uScorerBenchmark
1 parent 139cebc commit 72e1a84

File tree

2 files changed

+85
-39
lines changed

2 files changed

+85
-39
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java renamed to benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,23 @@
5555
/**
5656
* Benchmark that compares various scalar quantized vector similarity function
5757
* implementations;: scalar, lucene's panama-ized, and Elasticsearch's native.
58-
* Run with ./gradlew -p benchmarks run --args 'VectorScorerBenchmark'
58+
* Run with ./gradlew -p benchmarks run --args 'Int7uScorerBenchmark'
5959
*/
60-
public class VectorScorerBenchmark {
60+
public class Int7uScorerBenchmark {
6161

6262
static {
6363
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
6464
}
6565

6666
@Param({ "96", "768", "1024" })
67-
int dims;
68-
int size = 2; // there are only two vectors to compare
67+
public int dims;
68+
final int size = 2; // there are only two vectors to compare
6969

7070
Directory dir;
7171
IndexInput in;
7272
VectorScorerFactory factory;
7373

74-
byte[] vec1;
75-
byte[] vec2;
74+
byte[] vec1, vec2;
7675
float vec1Offset;
7776
float vec2Offset;
7877
float scoreCorrectionConstant;
@@ -139,39 +138,6 @@ public void setup() throws IOException {
139138
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get();
140139
luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec);
141140
nativeSqrScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get();
142-
143-
// sanity
144-
var f1 = dotProductLucene();
145-
var f2 = dotProductNative();
146-
var f3 = dotProductScalar();
147-
if (f1 != f2) {
148-
throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]");
149-
}
150-
if (f1 != f3) {
151-
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
152-
}
153-
// square distance
154-
f1 = squareDistanceLucene();
155-
f2 = squareDistanceNative();
156-
f3 = squareDistanceScalar();
157-
if (f1 != f2) {
158-
throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]");
159-
}
160-
if (f1 != f3) {
161-
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
162-
}
163-
164-
var q1 = dotProductLuceneQuery();
165-
var q2 = dotProductNativeQuery();
166-
if (q1 != q2) {
167-
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
168-
}
169-
170-
var sqr1 = squareDistanceLuceneQuery();
171-
var sqr2 = squareDistanceNativeQuery();
172-
if (sqr1 != sqr2) {
173-
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
174-
}
175141
}
176142

177143
@TearDown
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.benchmark.vector;
11+
12+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13+
14+
import org.apache.lucene.util.Constants;
15+
import org.elasticsearch.test.ESTestCase;
16+
import org.junit.BeforeClass;
17+
import org.openjdk.jmh.annotations.Param;
18+
19+
import java.util.Arrays;
20+
21+
public class Int7uScorerBenchmarkTests extends ESTestCase {
22+
23+
final double delta = 1e-3;
24+
final int dims;
25+
26+
public Int7uScorerBenchmarkTests(int dims) {
27+
this.dims = dims;
28+
}
29+
30+
@BeforeClass
31+
public static void skipWindows() {
32+
assumeFalse("doesn't work on windows yet", Constants.WINDOWS);
33+
}
34+
35+
public void testDotProduct() throws Exception {
36+
for (int i = 0; i < 100; i++) {
37+
var bench = new Int7uScorerBenchmark();
38+
bench.dims = dims;
39+
bench.setup();
40+
try {
41+
float expected = bench.dotProductScalar();
42+
assertEquals(expected, bench.dotProductLucene(), delta);
43+
assertEquals(expected, bench.dotProductNative(), delta);
44+
45+
expected = bench.dotProductLuceneQuery();
46+
assertEquals(expected, bench.dotProductNativeQuery(), delta);
47+
} finally {
48+
bench.teardown();
49+
}
50+
}
51+
}
52+
53+
public void testSquareDistance() throws Exception {
54+
for (int i = 0; i < 100; i++) {
55+
var bench = new Int7uScorerBenchmark();
56+
bench.dims = dims;
57+
bench.setup();
58+
try {
59+
float expected = bench.squareDistanceScalar();
60+
assertEquals(expected, bench.squareDistanceLucene(), delta);
61+
assertEquals(expected, bench.squareDistanceNative(), delta);
62+
63+
expected = bench.squareDistanceLuceneQuery();
64+
assertEquals(expected, bench.squareDistanceNativeQuery(), delta);
65+
} finally {
66+
bench.teardown();
67+
}
68+
}
69+
}
70+
71+
@ParametersFactory
72+
public static Iterable<Object[]> parametersFactory() {
73+
try {
74+
var params = Int7uScorerBenchmark.class.getField("dims").getAnnotationsByType(Param.class)[0].value();
75+
return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
76+
} catch (NoSuchFieldException e) {
77+
throw new AssertionError(e);
78+
}
79+
}
80+
}

0 commit comments

Comments
 (0)