Skip to content

Commit a9777af

Browse files
authored
Fix VectorScorerOSQBenchmarkTests (elastic#142742)
This PR fixes several VectorScorerOSQBenchmarkTests failures. The native C code (score_common.h) is either compiled with -O3 (e.g. on aarch64), which enables FMA (fused multiply-add) instructions, or uses FMA SIMD instructions directly. Java's float arithmetic rounds at every operation per IEEE 754, but FMA computes a * b + c with a single rounding, preserving extra intermediate precision. The score correction formula involves terms like ax * ay * dims + ay * lx * tcs + .... When correction values come from random bytes, they can represent extreme floats (e.g., ±1e+30). These large terms nearly cancel, and the FMA vs non-FMA rounding difference becomes catastrophic — the scalar gets score ≈ -1 (result 0.0) while the native gets score ≈ 5.5e+37 (result 2.7e+37). This is NOT an issue in production — real quantized corrections from OptimizedScalarQuantizer produce small, well-behaved floats where FMA differences are negligible (the ESNextOSQVectorsScorerTests.testScoreBulk test validates this with real data). The PR extracts the same input data generation used in ESNextOSQVectorsScorerTests.testScoreBulk, extracting it and exposing it via test fixture, so both tests share the same input data setup. Fixes elastic#142289 Fixes elastic#142413 Fixes elastic#142490 Fixes elastic#142491 Fixes elastic#142492 Fixes elastic#142587 Fixes elastic#142588 Fixes elastic#142589
1 parent 605c05b commit a9777af

File tree

10 files changed

+268
-226
lines changed

10 files changed

+268
-226
lines changed

benchmarks/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ dependencies {
5151
api(project(':x-pack:plugin:logsdb'))
5252
implementation project(path: ':libs:native')
5353
implementation project(path: ':libs:simdvec')
54+
implementation (testFixtures(project(path: ':libs:simdvec')))
5455
implementation project(path: ':libs:swisshash')
5556
implementation project(path: ':libs:exponential-histogram')
5657
implementation(project(':x-pack:plugin:searchable-snapshots')) {

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import org.apache.lucene.store.IndexOutput;
1717
import org.apache.lucene.store.MMapDirectory;
1818
import org.apache.lucene.store.NIOFSDirectory;
19-
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
19+
import org.apache.lucene.util.VectorUtil;
2020
import org.elasticsearch.common.logging.LogConfigurator;
2121
import org.elasticsearch.core.IOUtils;
2222
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;
2323
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
2424
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
25+
import org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils;
2526
import org.elasticsearch.xpack.searchablesnapshots.store.SearchableSnapshotDirectoryFactory;
2627
import org.openjdk.jmh.annotations.Benchmark;
2728
import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -42,6 +43,11 @@
4243
import java.util.Random;
4344
import java.util.concurrent.TimeUnit;
4445

46+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.createOSQIndexData;
47+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.createOSQQueryData;
48+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.randomVector;
49+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.writeBulkOSQVectorData;
50+
4551
@BenchmarkMode(Mode.Throughput)
4652
@OutputTimeUnit(TimeUnit.MILLISECONDS)
4753
@State(Scope.Benchmark)
@@ -72,7 +78,7 @@ public enum VectorImplementation {
7278
public int dims;
7379

7480
@Param({ "1", "2", "4" })
75-
public int bits;
81+
public byte bits;
7682

7783
int bulkSize = ESNextOSQVectorsScorer.BULK_SIZE;
7884

@@ -90,9 +96,7 @@ public enum VectorImplementation {
9096

9197
int length;
9298

93-
byte[][] binaryVectors;
94-
byte[][] binaryQueries;
95-
OptimizedScalarQuantizer.QuantizationResult result;
99+
VectorScorerTestUtils.OSQVectorData[] binaryQueries;
96100
float centroidDp;
97101

98102
byte[] scratch;
@@ -111,17 +115,12 @@ public void setup() throws IOException {
111115
}
112116

113117
void setup(Random random) throws IOException {
114-
this.length = switch (bits) {
115-
case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY.getDocPackedLength(dims);
116-
case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY.getDocPackedLength(dims);
117-
case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC.getDocPackedLength(dims);
118-
default -> throw new IllegalArgumentException("Unsupported bits: " + bits);
119-
};
118+
this.length = ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(bits).getDocPackedLength(dims);
120119

121-
binaryVectors = new byte[numVectors][length];
122-
for (byte[] binaryVector : binaryVectors) {
123-
random.nextBytes(binaryVector);
124-
}
120+
final float[] centroid = new float[dims];
121+
randomVector(random, centroid, similarityFunction);
122+
123+
var quantizer = new org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer(similarityFunction);
125124

126125
directory = switch (directoryType) {
127126
case MMAP -> new MMapDirectory(createTempDirectory("vectorDataMmap"));
@@ -130,35 +129,27 @@ void setup(Random random) throws IOException {
130129
};
131130

132131
try (IndexOutput output = directory.createOutput("vectors", IOContext.DEFAULT)) {
133-
byte[] correctionBytes = new byte[16 * bulkSize];
132+
VectorScorerTestUtils.OSQVectorData[] vectors = new VectorScorerTestUtils.OSQVectorData[bulkSize];
134133
for (int i = 0; i < numVectors; i += bulkSize) {
135134
for (int j = 0; j < bulkSize; j++) {
136-
output.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
135+
var vector = new float[dims];
136+
randomVector(random, vector, similarityFunction);
137+
vectors[j] = createOSQIndexData(vector, centroid, quantizer, dims, bits, length);
137138
}
138-
random.nextBytes(correctionBytes);
139-
output.writeBytes(correctionBytes, 0, correctionBytes.length);
139+
writeBulkOSQVectorData(bulkSize, output, vectors);
140140
}
141141
CodecUtil.writeFooter(output);
142142
}
143143
input = directory.openInput("vectors", IOContext.DEFAULT);
144-
int binaryQueryLength = switch (bits) {
145-
case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY.getQueryPackedLength(dims);
146-
case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY.getQueryPackedLength(dims);
147-
case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC.getQueryPackedLength(dims);
148-
default -> throw new IllegalArgumentException("Unsupported bits: " + bits);
149-
};
144+
int binaryQueryLength = ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(bits).getQueryPackedLength(dims);
150145

151-
binaryQueries = new byte[numVectors][binaryQueryLength];
152-
for (byte[] binaryQuery : binaryQueries) {
153-
random.nextBytes(binaryQuery);
146+
binaryQueries = new VectorScorerTestUtils.OSQVectorData[numVectors];
147+
var query = new float[dims];
148+
for (int i = 0; i < numVectors; ++i) {
149+
randomVector(random, query, similarityFunction);
150+
binaryQueries[i] = createOSQQueryData(query, centroid, quantizer, dims, (byte) 4, binaryQueryLength);
154151
}
155-
result = new OptimizedScalarQuantizer.QuantizationResult(
156-
random.nextFloat(),
157-
random.nextFloat(),
158-
random.nextFloat(),
159-
Short.toUnsignedInt((short) random.nextInt())
160-
);
161-
centroidDp = random.nextFloat();
152+
centroidDp = VectorUtil.dotProduct(centroid, centroid);
162153

163154
scratch = new byte[length];
164155
final int docBits;
@@ -202,14 +193,14 @@ public float[] score() throws IOException {
202193
for (int j = 0; j < numQueries; j++) {
203194
input.seek(0);
204195
for (int i = 0; i < numVectors; i++) {
205-
float qDist = scorer.quantizeScore(binaryQueries[j]);
196+
float qDist = scorer.quantizeScore(binaryQueries[j].quantizedVector());
206197
input.readFloats(corrections, 0, corrections.length);
207198
int addition = Short.toUnsignedInt(input.readShort());
208199
float score = scorer.score(
209-
result.lowerInterval(),
210-
result.upperInterval(),
211-
result.quantizedComponentSum(),
212-
result.additionalCorrection(),
200+
binaryQueries[j].lowerInterval(),
201+
binaryQueries[j].upperInterval(),
202+
binaryQueries[j].quantizedComponentSum(),
203+
binaryQueries[j].additionalCorrection(),
213204
similarityFunction,
214205
centroidDp,
215206
corrections[0],
@@ -231,11 +222,11 @@ public float[] bulkScore() throws IOException {
231222
input.seek(0);
232223
for (int i = 0; i < numVectors; i += scratchScores.length) {
233224
scorer.scoreBulk(
234-
binaryQueries[j],
235-
result.lowerInterval(),
236-
result.upperInterval(),
237-
result.quantizedComponentSum(),
238-
result.additionalCorrection(),
225+
binaryQueries[j].quantizedVector(),
226+
binaryQueries[j].lowerInterval(),
227+
binaryQueries[j].upperInterval(),
228+
binaryQueries[j].quantizedComponentSum(),
229+
binaryQueries[j].additionalCorrection(),
239230
similarityFunction,
240231
centroidDp,
241232
scratchScores

benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmarkTests.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.apache.lucene.util.Constants;
1616
import org.elasticsearch.core.IOUtils;
1717
import org.elasticsearch.test.ESTestCase;
18+
import org.elasticsearch.test.junit.annotations.TestLogging;
1819
import org.junit.BeforeClass;
1920
import org.openjdk.jmh.annotations.Param;
2021

@@ -24,17 +25,22 @@
2425

2526
import static org.elasticsearch.common.util.CollectionUtils.appendToCopy;
2627

28+
@TestLogging(
29+
reason = "Noisy logging",
30+
value = "org.elasticsearch.env.NodeEnvironment:WARN,org.elasticsearch.xpack.searchablesnapshots.cache.full.PersistentCache:WARN"
31+
)
2732
public class VectorScorerOSQBenchmarkTests extends ESTestCase {
2833

34+
private static final int REPETITIONS = 10;
2935
private final float deltaPercent = 0.1f;
3036
private final int dims;
31-
private final int bits;
37+
private final byte bits;
3238
private final VectorScorerOSQBenchmark.DirectoryType directoryType;
3339
private final VectorSimilarityFunction similarityFunction;
3440

3541
public VectorScorerOSQBenchmarkTests(
3642
int dims,
37-
int bits,
43+
byte bits,
3844
VectorScorerOSQBenchmark.DirectoryType directoryType,
3945
VectorSimilarityFunction similarityFunction
4046
) {
@@ -50,7 +56,7 @@ public static void skipWindows() {
5056
}
5157

5258
public void testSingleScalarVsVectorized() throws Exception {
53-
for (int i = 0; i < 100; i++) {
59+
for (int i = 0; i < REPETITIONS; i++) {
5460
var seed = randomLong();
5561

5662
var scalar = new VectorScorerOSQBenchmark();
@@ -85,7 +91,7 @@ public void testSingleScalarVsVectorized() throws Exception {
8591
}
8692

8793
public void testBulkScalarVsVectorized() throws Exception {
88-
for (int i = 0; i < 100; i++) {
94+
for (int i = 0; i < REPETITIONS; i++) {
8995
var seed = randomLong();
9096

9197
var scalar = new VectorScorerOSQBenchmark();
@@ -128,7 +134,7 @@ public static Iterable<Object[]> parametersFactory() {
128134

129135
return () -> Arrays.stream(dims)
130136
.map(Integer::parseInt)
131-
.flatMap(d -> Arrays.stream(bits).map(Integer::parseInt).map(b -> List.<Object>of(d, b)))
137+
.flatMap(d -> Arrays.stream(bits).map(Byte::parseByte).map(b -> List.<Object>of(d, b)))
132138
.flatMap(params -> Arrays.stream(VectorScorerOSQBenchmark.DirectoryType.values()).map(dir -> appendToCopy(params, dir)))
133139
.flatMap(params -> Arrays.stream(VectorSimilarityFunction.values()).map(f -> appendToCopy(params, f).toArray()))
134140
.iterator();

libs/simdvec/build.gradle

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,20 @@ import org.elasticsearch.gradle.internal.precommit.CheckForbiddenApisTask
1212
apply plugin: 'elasticsearch.publish'
1313
apply plugin: 'elasticsearch.build'
1414
apply plugin: 'elasticsearch.mrjar'
15+
apply plugin: 'java-test-fixtures'
1516

1617
dependencies {
1718
implementation project(':libs:native')
1819
implementation project(':libs:logging')
1920
implementation "org.apache.lucene:lucene-core:${versions.lucene}"
2021

22+
testImplementation(testArtifact(project(':x-pack:plugin:searchable-snapshots')))
2123
testImplementation(project(":test:framework")) {
2224
exclude group: 'org.elasticsearch', module: 'native'
2325
}
24-
testImplementation(testArtifact(project(':x-pack:plugin:searchable-snapshots')))
26+
testFixturesImplementation(project(":test:framework")) {
27+
exclude group: 'org.elasticsearch', module: 'native'
28+
}
2529
}
2630

2731
// compileMain21Java does not exist within idea (see MrJarPlugin) so we cannot reference directly by name

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ public void testScore() throws Exception {
120120
long qDist = defaultScorer.quantizeScore(quantizeQuery);
121121
slice.readFloats(floatScratch, 0, 3);
122122
int quantizedComponentSum = slice.readShort();
123-
float defaulScore = defaultScorer.score(
123+
float defaultScore = defaultScorer.score(
124124
queryCorrections.lowerInterval(),
125125
queryCorrections.upperInterval(),
126126
queryCorrections.quantizedComponentSum(),
@@ -149,7 +149,7 @@ public void testScore() throws Exception {
149149
floatScratch[2],
150150
qDist
151151
);
152-
assertEquals(defaulScore, panamaScore, 1e-2f);
152+
assertEquals(defaultScore, panamaScore, 1e-2f);
153153
assertEquals(((long) (i + 1) * (length + 14)), slice.getFilePointer());
154154
assertEquals(padding + ((long) (i + 1) * (length + 14)), in.getFilePointer());
155155
}
@@ -234,9 +234,7 @@ public void testScoreBulk() throws Exception {
234234
scoresPanama
235235
);
236236
assertEquals(defaultMaxScore, panamaMaxScore, 1e-2f);
237-
for (int j = 0; j < bulkSize; j++) {
238-
assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
239-
}
237+
assertArrayEqualsPercent(scoresDefault, scoresPanama, 0.05f, 1e-2f);
240238
assertEquals(((long) (bulkSize) * (length + 14)), slice.getFilePointer());
241239
assertEquals(padding + ((long) (i + bulkSize) * (length + 14)), in.getFilePointer());
242240
}

0 commit comments

Comments
 (0)