Skip to content

Commit 63e2a3b

Browse files
authored
Fix max score calculation in MemorySegmentES91OSQVectorsScorer (#132433)
There is an error in how we compute max score in our panamized version of ES91OSQVectorsScorer after #132293. This commit fixes it and increases test coverage.
1 parent 364c70e commit 63e2a3b

File tree

2 files changed

+199
-89
lines changed

2 files changed

+199
-89
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ private float score128Bulk(
452452
if (similarityFunction == EUCLIDEAN) {
453453
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
454454
res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
455-
maxScore = res.reduceLanes(VectorOperators.MAX);
455+
maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
456456
res.intoArray(scores, i);
457457
} else {
458458
// For cosine and max inner product, we need to apply the additional correction, which is
@@ -468,7 +468,7 @@ private float score128Bulk(
468468
} else {
469469
res = res.add(1f).mul(0.5f).max(0);
470470
res.intoArray(scores, i);
471-
maxScore = res.reduceLanes(VectorOperators.MAX);
471+
maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
472472
}
473473
}
474474
}
@@ -527,7 +527,7 @@ private float score256Bulk(
527527
if (similarityFunction == EUCLIDEAN) {
528528
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
529529
res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
530-
maxScore = res.reduceLanes(VectorOperators.MAX);
530+
maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
531531
res.intoArray(scores, i);
532532
} else {
533533
// For cosine and max inner product, we need to apply the additional correction, which is
@@ -542,7 +542,7 @@ private float score256Bulk(
542542
}
543543
} else {
544544
res = res.add(1f).mul(0.5f).max(0);
545-
maxScore = res.reduceLanes(VectorOperators.MAX);
545+
maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
546546
res.intoArray(scores, i);
547547
}
548548
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java

Lines changed: 195 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
import org.apache.lucene.store.IndexInput;
1616
import org.apache.lucene.store.IndexOutput;
1717
import org.apache.lucene.store.MMapDirectory;
18-
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
18+
import org.apache.lucene.util.VectorUtil;
19+
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
20+
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
21+
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
22+
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1923
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2024

21-
import static org.hamcrest.Matchers.lessThan;
25+
import java.io.IOException;
2226

2327
public class ES91OSQVectorScorerTests extends BaseVectorizationTests {
2428

2529
public void testQuantizeScore() throws Exception {
2630
final int dimensions = random().nextInt(1, 2000);
27-
final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
31+
final int length = BQVectorUtils.discretize(dimensions, 64) / 8;
2832
final int numVectors = random().nextInt(1, 100);
2933
final byte[] vector = new byte[length];
3034
try (Directory dir = new MMapDirectory(createTempDir())) {
@@ -53,102 +57,208 @@ public void testQuantizeScore() throws Exception {
5357
}
5458

5559
public void testScore() throws Exception {
56-
final int maxDims = 512;
60+
final int maxDims = random().nextInt(1, 1000) * 2;
5761
final int dimensions = random().nextInt(1, maxDims);
58-
final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
59-
final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
60-
final byte[] vector = new byte[length];
62+
final int length = BQVectorUtils.discretize(dimensions, 64) / 8;
63+
final int numVectors = random().nextInt(10, 50);
64+
float[][] vectors = new float[numVectors][dimensions];
65+
final int[] scratch = new int[dimensions];
66+
final byte[] qVector = new byte[length];
67+
final float[] centroid = new float[dimensions];
68+
VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
69+
randomVector(centroid, similarityFunction);
70+
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
6171
int padding = random().nextInt(100);
6272
byte[] paddingBytes = new byte[padding];
6373
try (Directory dir = new MMapDirectory(createTempDir())) {
6474
try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {
6575
random().nextBytes(paddingBytes);
6676
out.writeBytes(paddingBytes, 0, padding);
77+
for (float[] vector : vectors) {
78+
randomVector(vector, similarityFunction);
79+
OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(
80+
vector.clone(),
81+
scratch,
82+
(byte) 1,
83+
centroid
84+
);
85+
BQVectorUtils.packAsBinary(scratch, qVector);
86+
out.writeBytes(qVector, 0, qVector.length);
87+
out.writeInt(Float.floatToIntBits(result.lowerInterval()));
88+
out.writeInt(Float.floatToIntBits(result.upperInterval()));
89+
out.writeInt(Float.floatToIntBits(result.additionalCorrection()));
90+
out.writeShort((short) result.quantizedComponentSum());
91+
}
92+
}
93+
final float[] query = new float[dimensions];
94+
randomVector(query, similarityFunction);
95+
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
96+
query.clone(),
97+
scratch,
98+
(byte) 4,
99+
centroid
100+
);
101+
final byte[] quantizeQuery = new byte[4 * length];
102+
BQSpaceUtils.transposeHalfByte(scratch, quantizeQuery);
103+
final float centroidDp = VectorUtil.dotProduct(centroid, centroid);
104+
final float[] floatScratch = new float[3];
105+
try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
106+
in.seek(padding);
107+
assertEquals(in.length(), padding + (long) numVectors * (length + 14));
108+
final IndexInput slice = in.slice("test", in.getFilePointer(), (long) (length + 14) * numVectors);
109+
// Work on a slice that has just the right number of bytes to make the test fail with an
110+
// index-out-of-bounds in case the implementation reads more than the allowed number of
111+
// padding bytes.
67112
for (int i = 0; i < numVectors; i++) {
68-
random().nextBytes(vector);
69-
out.writeBytes(vector, 0, length);
70-
float lower = random().nextFloat();
71-
float upper = random().nextFloat() + lower / 2;
72-
float additionalCorrection = random().nextFloat();
73-
int targetComponentSum = randomIntBetween(0, dimensions / 2);
74-
out.writeInt(Float.floatToIntBits(lower));
75-
out.writeInt(Float.floatToIntBits(upper));
76-
out.writeShort((short) targetComponentSum);
77-
out.writeInt(Float.floatToIntBits(additionalCorrection));
113+
final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
114+
final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
115+
long qDist = defaultScorer.quantizeScore(quantizeQuery);
116+
slice.readFloats(floatScratch, 0, 3);
117+
int quantizedComponentSum = slice.readShort();
118+
float defaulScore = defaultScorer.score(
119+
queryCorrections.lowerInterval(),
120+
queryCorrections.upperInterval(),
121+
queryCorrections.quantizedComponentSum(),
122+
queryCorrections.additionalCorrection(),
123+
similarityFunction,
124+
centroidDp,
125+
floatScratch[0],
126+
floatScratch[1],
127+
quantizedComponentSum,
128+
floatScratch[2],
129+
qDist
130+
);
131+
qDist = panamaScorer.quantizeScore(quantizeQuery);
132+
in.readFloats(floatScratch, 0, 3);
133+
quantizedComponentSum = in.readShort();
134+
float panamaScore = panamaScorer.score(
135+
queryCorrections.lowerInterval(),
136+
queryCorrections.upperInterval(),
137+
queryCorrections.quantizedComponentSum(),
138+
queryCorrections.additionalCorrection(),
139+
similarityFunction,
140+
centroidDp,
141+
floatScratch[0],
142+
floatScratch[1],
143+
quantizedComponentSum,
144+
floatScratch[2],
145+
qDist
146+
);
147+
assertEquals(defaulScore, panamaScore, 1e-2f);
148+
assertEquals(((long) (i + 1) * (length + 14)), slice.getFilePointer());
149+
assertEquals(padding + ((long) (i + 1) * (length + 14)), in.getFilePointer());
78150
}
79151
}
80-
final byte[] query = new byte[4 * length];
81-
random().nextBytes(query);
82-
float lower = random().nextFloat();
83-
OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult(
84-
lower,
85-
random().nextFloat() + lower / 2,
86-
random().nextFloat(),
87-
randomIntBetween(0, dimensions * 2)
152+
}
153+
}
154+
155+
public void testScoreBulk() throws Exception {
156+
final int maxDims = random().nextInt(1, 1000) * 2;
157+
final int dimensions = random().nextInt(1, maxDims);
158+
final int length = BQVectorUtils.discretize(dimensions, 64) / 8;
159+
final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
160+
float[][] vectors = new float[numVectors][dimensions];
161+
final int[] scratch = new int[dimensions];
162+
final byte[] qVector = new byte[length];
163+
final float[] centroid = new float[dimensions];
164+
VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
165+
randomVector(centroid, similarityFunction);
166+
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
167+
int padding = random().nextInt(100);
168+
byte[] paddingBytes = new byte[padding];
169+
try (Directory dir = new MMapDirectory(createTempDir())) {
170+
try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {
171+
random().nextBytes(paddingBytes);
172+
out.writeBytes(paddingBytes, 0, padding);
173+
int limit = numVectors - ES91OSQVectorsScorer.BULK_SIZE + 1;
174+
OptimizedScalarQuantizer.QuantizationResult[] results =
175+
new OptimizedScalarQuantizer.QuantizationResult[ES91Int4VectorsScorer.BULK_SIZE];
176+
for (int i = 0; i < limit; i += ES91OSQVectorsScorer.BULK_SIZE) {
177+
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
178+
randomVector(vectors[i + j], similarityFunction);
179+
results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), scratch, (byte) 1, centroid);
180+
BQVectorUtils.packAsBinary(scratch, qVector);
181+
out.writeBytes(qVector, 0, qVector.length);
182+
}
183+
writeCorrections(results, out);
184+
}
185+
}
186+
final float[] query = new float[dimensions];
187+
randomVector(query, similarityFunction);
188+
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
189+
query.clone(),
190+
scratch,
191+
(byte) 4,
192+
centroid
88193
);
89-
final float centroidDp = random().nextFloat();
90-
final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE];
91-
final float[] scores2 = new float[ES91OSQVectorsScorer.BULK_SIZE];
92-
for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
93-
try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
94-
in.seek(padding);
95-
assertEquals(in.length(), padding + (long) numVectors * (length + 14));
96-
// Work on a slice that has just the right number of bytes to make the test fail with an
97-
// index-out-of-bounds in case the implementation reads more than the allowed number of
98-
// padding bytes.
99-
for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
100-
final IndexInput slice = in.slice(
101-
"test",
102-
in.getFilePointer(),
103-
(long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE
104-
);
105-
final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
106-
final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
107-
defaultScorer.scoreBulk(
108-
query,
109-
result.lowerInterval(),
110-
result.upperInterval(),
111-
result.quantizedComponentSum(),
112-
result.additionalCorrection(),
113-
similarityFunction,
114-
centroidDp,
115-
scores1
116-
);
117-
panamaScorer.scoreBulk(
118-
query,
119-
result.lowerInterval(),
120-
result.upperInterval(),
121-
result.quantizedComponentSum(),
122-
result.additionalCorrection(),
123-
similarityFunction,
124-
centroidDp,
125-
scores2
126-
);
127-
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
128-
if (scores1[j] == scores2[j]) {
129-
continue;
130-
}
131-
if (scores1[j] > (maxDims * Byte.MAX_VALUE)) {
132-
float diff = Math.abs(scores1[j] - scores2[j]);
133-
assertThat(
134-
"defaultScores: " + scores1[j] + " bulkScores: " + scores2[j],
135-
diff / scores1[j],
136-
lessThan(1e-5f)
137-
);
138-
assertThat(
139-
"defaultScores: " + scores1[j] + " bulkScores: " + scores2[j],
140-
diff / scores2[j],
141-
lessThan(1e-5f)
142-
);
143-
} else {
144-
assertEquals(scores1[j], scores2[j], 1e-2f);
145-
}
146-
}
147-
assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());
148-
assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());
194+
final byte[] quantizeQuery = new byte[4 * length];
195+
BQSpaceUtils.transposeHalfByte(scratch, quantizeQuery);
196+
final float centroidDp = VectorUtil.dotProduct(centroid, centroid);
197+
final float[] scoresDefault = new float[ES91OSQVectorsScorer.BULK_SIZE];
198+
final float[] scoresPanama = new float[ES91OSQVectorsScorer.BULK_SIZE];
199+
try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
200+
in.seek(padding);
201+
assertEquals(in.length(), padding + (long) numVectors * (length + 14));
202+
// Work on a slice that has just the right number of bytes to make the test fail with an
203+
// index-out-of-bounds in case the implementation reads more than the allowed number of
204+
// padding bytes.
205+
for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
206+
final IndexInput slice = in.slice("test", in.getFilePointer(), (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE);
207+
final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
208+
final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
209+
float defaultMaxScore = defaultScorer.scoreBulk(
210+
quantizeQuery,
211+
queryCorrections.lowerInterval(),
212+
queryCorrections.upperInterval(),
213+
queryCorrections.quantizedComponentSum(),
214+
queryCorrections.additionalCorrection(),
215+
similarityFunction,
216+
centroidDp,
217+
scoresDefault
218+
);
219+
float panamaMaxScore = panamaScorer.scoreBulk(
220+
quantizeQuery,
221+
queryCorrections.lowerInterval(),
222+
queryCorrections.upperInterval(),
223+
queryCorrections.quantizedComponentSum(),
224+
queryCorrections.additionalCorrection(),
225+
similarityFunction,
226+
centroidDp,
227+
scoresPanama
228+
);
229+
assertEquals(defaultMaxScore, panamaMaxScore, 1e-2f);
230+
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
231+
assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
149232
}
233+
assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());
234+
assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());
150235
}
151236
}
152237
}
153238
}
239+
240+
private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
241+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
242+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
243+
}
244+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
245+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
246+
}
247+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
248+
int targetComponentSum = correction.quantizedComponentSum();
249+
out.writeShort((short) targetComponentSum);
250+
}
251+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
252+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
253+
}
254+
}
255+
256+
private void randomVector(float[] vector, VectorSimilarityFunction vectorSimilarityFunction) {
257+
for (int i = 0; i < vector.length; i++) {
258+
vector[i] = random().nextFloat();
259+
}
260+
if (vectorSimilarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
261+
VectorUtil.l2normalize(vector);
262+
}
263+
}
154264
}

0 commit comments

Comments
 (0)