Skip to content

Commit 4e926ae

Browse files
authored
Minor ivf cleanups and fixing quantization performance (#129566)
We are accidentally utilizing the non-vectorized quantizer when building ivf indices. This provides a 3-5x speed improvement on quantizing on my mac This fixes that and addresses some minor fixes (removing unused code, etc.) Here is a small benchmark result. time spent quantizing goes down significantly. <img width="652" alt="image" src="https://github.com/user-attachments/assets/9f46398c-c587-4e74-bc91-f2e07a63b406" /> vs. <img width="673" alt="image" src="https://github.com/user-attachments/assets/c4f4679f-d7a7-4486-841f-7dd3e75a11cb" />
1 parent 90c24d0 commit 4e926ae

File tree

11 files changed

+139
-75
lines changed

11 files changed

+139
-75
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,10 @@ public void scoreFromArray(Blackhole bh) throws IOException {
126126
in.readFloats(corrections, 0, corrections.length);
127127
int addition = Short.toUnsignedInt(in.readShort());
128128
float score = scorer.score(
129-
result,
129+
result.lowerInterval(),
130+
result.upperInterval(),
131+
result.quantizedComponentSum(),
132+
result.additionalCorrection(),
130133
VectorSimilarityFunction.EUCLIDEAN,
131134
centroidDp,
132135
corrections[0],
@@ -150,7 +153,10 @@ public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
150153
in.readFloats(corrections, 0, corrections.length);
151154
int addition = Short.toUnsignedInt(in.readShort());
152155
float score = scorer.score(
153-
result,
156+
result.lowerInterval(),
157+
result.upperInterval(),
158+
result.quantizedComponentSum(),
159+
result.additionalCorrection(),
154160
VectorSimilarityFunction.EUCLIDEAN,
155161
centroidDp,
156162
corrections[0],
@@ -175,7 +181,10 @@ public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOExceptio
175181
in.readFloats(corrections, 0, corrections.length);
176182
int addition = Short.toUnsignedInt(in.readShort());
177183
float score = scorer.score(
178-
result,
184+
result.lowerInterval(),
185+
result.upperInterval(),
186+
result.quantizedComponentSum(),
187+
result.additionalCorrection(),
179188
VectorSimilarityFunction.EUCLIDEAN,
180189
centroidDp,
181190
corrections[0],
@@ -196,7 +205,16 @@ public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException {
196205
for (int j = 0; j < numQueries; j++) {
197206
in.seek(0);
198207
for (int i = 0; i < numVectors; i += 16) {
199-
scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores);
208+
scorer.scoreBulk(
209+
binaryQueries[j],
210+
result.lowerInterval(),
211+
result.upperInterval(),
212+
result.quantizedComponentSum(),
213+
result.additionalCorrection(),
214+
VectorSimilarityFunction.EUCLIDEAN,
215+
centroidDp,
216+
scratchScores
217+
);
200218
bh.consume(scratchScores);
201219
}
202220
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce
9595
* Computes the score by applying the necessary corrections to the provided quantized distance.
9696
*/
9797
public float score(
98-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
98+
float queryLowerInterval,
99+
float queryUpperInterval,
100+
int queryComponentSum,
101+
float queryAdditionalCorrection,
99102
VectorSimilarityFunction similarityFunction,
100103
float centroidDp,
101104
float lowerInterval,
@@ -107,19 +110,19 @@ public float score(
107110
float ax = lowerInterval;
108111
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
109112
float lx = upperInterval - ax;
110-
float ay = queryCorrections.lowerInterval();
111-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
112-
float y1 = queryCorrections.quantizedComponentSum();
113+
float ay = queryLowerInterval;
114+
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
115+
float y1 = queryComponentSum;
113116
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
114117
// For euclidean, we need to invert the score and apply the additional correction, which is
115118
// assumed to be the squared l2norm of the centroid centered vectors.
116119
if (similarityFunction == EUCLIDEAN) {
117-
score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score;
120+
score = queryAdditionalCorrection + additionalCorrection - 2 * score;
118121
return Math.max(1 / (1f + score), 0);
119122
} else {
120123
// For cosine and max inner product, we need to apply the additional correction, which is
121124
// assumed to be the non-centered dot-product between the vector and the centroid
122-
score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp;
125+
score += queryAdditionalCorrection + additionalCorrection - centroidDp;
123126
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
124127
return VectorUtil.scaleMaxInnerProductScore(score);
125128
}
@@ -140,7 +143,10 @@ public float score(
140143
*/
141144
public void scoreBulk(
142145
byte[] q,
143-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
146+
float queryLowerInterval,
147+
float queryUpperInterval,
148+
int queryComponentSum,
149+
float queryAdditionalCorrection,
144150
VectorSimilarityFunction similarityFunction,
145151
float centroidDp,
146152
float[] scores
@@ -154,7 +160,10 @@ public void scoreBulk(
154160
in.readFloats(additionalCorrections, 0, BULK_SIZE);
155161
for (int i = 0; i < BULK_SIZE; i++) {
156162
scores[i] = score(
157-
queryCorrections,
163+
queryLowerInterval,
164+
queryUpperInterval,
165+
queryComponentSum,
166+
queryAdditionalCorrection,
158167
similarityFunction,
159168
centroidDp,
160169
lowerIntervals[i],

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.apache.lucene.index.VectorSimilarityFunction;
2020
import org.apache.lucene.store.IndexInput;
2121
import org.apache.lucene.util.VectorUtil;
22-
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
2322
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2423

2524
import java.io.IOException;
@@ -298,7 +297,10 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO
298297
@Override
299298
public void scoreBulk(
300299
byte[] q,
301-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
300+
float queryLowerInterval,
301+
float queryUpperInterval,
302+
int queryComponentSum,
303+
float queryAdditionalCorrection,
302304
VectorSimilarityFunction similarityFunction,
303305
float centroidDp,
304306
float[] scores
@@ -307,19 +309,49 @@ public void scoreBulk(
307309
// 128 / 8 == 16
308310
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
309311
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
310-
score256Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
312+
score256Bulk(
313+
q,
314+
queryLowerInterval,
315+
queryUpperInterval,
316+
queryComponentSum,
317+
queryAdditionalCorrection,
318+
similarityFunction,
319+
centroidDp,
320+
scores
321+
);
311322
return;
312323
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
313-
score128Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
324+
score128Bulk(
325+
q,
326+
queryLowerInterval,
327+
queryUpperInterval,
328+
queryComponentSum,
329+
queryAdditionalCorrection,
330+
similarityFunction,
331+
centroidDp,
332+
scores
333+
);
314334
return;
315335
}
316336
}
317-
super.scoreBulk(q, queryCorrections, similarityFunction, centroidDp, scores);
337+
super.scoreBulk(
338+
q,
339+
queryLowerInterval,
340+
queryUpperInterval,
341+
queryComponentSum,
342+
queryAdditionalCorrection,
343+
similarityFunction,
344+
centroidDp,
345+
scores
346+
);
318347
}
319348

320349
private void score128Bulk(
321350
byte[] q,
322-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
351+
float queryLowerInterval,
352+
float queryUpperInterval,
353+
int queryComponentSum,
354+
float queryAdditionalCorrection,
323355
VectorSimilarityFunction similarityFunction,
324356
float centroidDp,
325357
float[] scores
@@ -328,9 +360,9 @@ private void score128Bulk(
328360
int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE);
329361
int i = 0;
330362
long offset = in.getFilePointer();
331-
float ay = queryCorrections.lowerInterval();
332-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
333-
float y1 = queryCorrections.quantizedComponentSum();
363+
float ay = queryLowerInterval;
364+
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
365+
float y1 = queryComponentSum;
334366
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
335367
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
336368
var lx = FloatVector.fromMemorySegment(
@@ -362,13 +394,13 @@ private void score128Bulk(
362394
// For euclidean, we need to invert the score and apply the additional correction, which is
363395
// assumed to be the squared l2norm of the centroid centered vectors.
364396
if (similarityFunction == EUCLIDEAN) {
365-
res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
397+
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
366398
res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
367399
res.intoArray(scores, i);
368400
} else {
369401
// For cosine and max inner product, we need to apply the additional correction, which is
370402
// assumed to be the non-centered dot-product between the vector and the centroid
371-
res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
403+
res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
372404
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
373405
res.intoArray(scores, i);
374406
// not sure how to do it better
@@ -386,7 +418,10 @@ private void score128Bulk(
386418

387419
private void score256Bulk(
388420
byte[] q,
389-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
421+
float queryLowerInterval,
422+
float queryUpperInterval,
423+
int queryComponentSum,
424+
float queryAdditionalCorrection,
390425
VectorSimilarityFunction similarityFunction,
391426
float centroidDp,
392427
float[] scores
@@ -395,9 +430,9 @@ private void score256Bulk(
395430
int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE);
396431
int i = 0;
397432
long offset = in.getFilePointer();
398-
float ay = queryCorrections.lowerInterval();
399-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
400-
float y1 = queryCorrections.quantizedComponentSum();
433+
float ay = queryLowerInterval;
434+
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
435+
float y1 = queryComponentSum;
401436
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
402437
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
403438
var lx = FloatVector.fromMemorySegment(
@@ -429,13 +464,13 @@ private void score256Bulk(
429464
// For euclidean, we need to invert the score and apply the additional correction, which is
430465
// assumed to be the squared l2norm of the centroid centered vectors.
431466
if (similarityFunction == EUCLIDEAN) {
432-
res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
467+
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
433468
res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
434469
res.intoArray(scores, i);
435470
} else {
436471
// For cosine and max inner product, we need to apply the additional correction, which is
437472
// assumed to be the non-centered dot-product between the vector and the centroid
438-
res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
473+
res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
439474
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
440475
res.intoArray(scores, i);
441476
// not sure how to do it better

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,26 @@ public void testScore() throws Exception {
104104
);
105105
final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
106106
final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
107-
defaultScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores1);
108-
panamaScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores2);
107+
defaultScorer.scoreBulk(
108+
query,
109+
result.lowerInterval(),
110+
result.upperInterval(),
111+
result.quantizedComponentSum(),
112+
result.additionalCorrection(),
113+
similarityFunction,
114+
centroidDp,
115+
scores1
116+
);
117+
panamaScorer.scoreBulk(
118+
query,
119+
result.lowerInterval(),
120+
result.upperInterval(),
121+
result.quantizedComponentSum(),
122+
result.additionalCorrection(),
123+
similarityFunction,
124+
centroidDp,
125+
scores2
126+
);
109127
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
110128
if (scores1[j] > (maxDims * Short.MAX_VALUE)) {
111129
int diff = (int) (scores1[j] - scores2[j]);

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.apache.lucene.util.ArrayUtil;
1919
import org.apache.lucene.util.VectorUtil;
2020
import org.apache.lucene.util.hnsw.NeighborQueue;
21-
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
2221
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
2322
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2423
import org.elasticsearch.simdvec.ESVectorUtil;
@@ -31,8 +30,8 @@
3130
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
3231
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
3332
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
34-
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
35-
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;
33+
import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte;
34+
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
3635
import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE;
3736

3837
/**
@@ -47,13 +46,8 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
4746
}
4847

4948
@Override
50-
CentroidQueryScorer getCentroidScorer(
51-
FieldInfo fieldInfo,
52-
int numCentroids,
53-
IndexInput centroids,
54-
float[] targetQuery,
55-
IndexInput clusters
56-
) throws IOException {
49+
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
50+
throws IOException {
5751
FieldEntry fieldEntry = fields.get(fieldInfo.number);
5852
float[] globalCentroid = fieldEntry.globalCentroid();
5953
float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -259,7 +253,10 @@ void scoreIndividually(int offset) throws IOException {
259253
int doc = docIdsScratch[offset + j];
260254
if (doc != -1) {
261255
scores[j] = osqVectorsScorer.score(
262-
queryCorrections,
256+
queryCorrections.lowerInterval(),
257+
queryCorrections.upperInterval(),
258+
queryCorrections.quantizedComponentSum(),
259+
queryCorrections.additionalCorrection(),
263260
fieldInfo.getVectorSimilarityFunction(),
264261
centroidDp,
265262
correctionsLower[j],
@@ -297,7 +294,10 @@ public int visit(KnnCollector knnCollector) throws IOException {
297294
} else {
298295
osqVectorsScorer.scoreBulk(
299296
quantizedQueryScratch,
300-
queryCorrections,
297+
queryCorrections.lowerInterval(),
298+
queryCorrections.upperInterval(),
299+
queryCorrections.quantizedComponentSum(),
300+
queryCorrections.additionalCorrection(),
301301
fieldInfo.getVectorSimilarityFunction(),
302302
centroidDp,
303303
scores
@@ -321,7 +321,10 @@ public int visit(KnnCollector knnCollector) throws IOException {
321321
indexInput.readFloats(correctiveValues, 0, 3);
322322
final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
323323
float score = osqVectorsScorer.score(
324-
queryCorrections,
324+
queryCorrections.lowerInterval(),
325+
queryCorrections.upperInterval(),
326+
queryCorrections.quantizedComponentSum(),
327+
queryCorrections.additionalCorrection(),
325328
fieldInfo.getVectorSimilarityFunction(),
326329
centroidDp,
327330
correctiveValues[0],

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.apache.lucene.store.IndexInput;
1919
import org.apache.lucene.store.IndexOutput;
2020
import org.apache.lucene.util.VectorUtil;
21-
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
2221
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
2322
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
2423
import org.elasticsearch.logging.LogManager;
@@ -30,8 +29,8 @@
3029
import java.nio.ByteOrder;
3130

3231
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
33-
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
34-
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
32+
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
33+
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
3534

3635
/**
3736
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to

0 commit comments

Comments
 (0)