Skip to content

Commit e087048

Browse files
authored
[DiskBBQ] Quantize centroids using 7 bits instead of 4 bits (#132261)
1 parent 6ca0466 commit e087048

File tree

3 files changed

+73
-51
lines changed

3 files changed

+73
-51
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import org.apache.lucene.util.VectorUtil;
2020
import org.apache.lucene.util.hnsw.NeighborQueue;
2121
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
22-
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
2322
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
23+
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
2424
import org.elasticsearch.simdvec.ESVectorUtil;
2525

2626
import java.io.IOException;
@@ -61,14 +61,14 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
6161
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
6262
targetQueryCopy,
6363
scratch,
64-
(byte) 4,
64+
(byte) 7,
6565
fieldEntry.globalCentroid()
6666
);
6767
final byte[] quantized = new byte[targetQuery.length];
6868
for (int i = 0; i < quantized.length; i++) {
6969
quantized[i] = (byte) scratch[i];
7070
}
71-
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
71+
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
7272
centroids.seek(0L);
7373
int numParents = centroids.readVInt();
7474
if (numParents > 0) {
@@ -90,7 +90,7 @@ private static CentroidIterator getCentroidIteratorNoParent(
9090
FieldInfo fieldInfo,
9191
IndexInput centroids,
9292
int numCentroids,
93-
ES91Int4VectorsScorer scorer,
93+
ES92Int7VectorsScorer scorer,
9494
byte[] quantizeQuery,
9595
OptimizedScalarQuantizer.QuantizationResult queryParams,
9696
float globalCentroidDp
@@ -105,7 +105,7 @@ private static CentroidIterator getCentroidIteratorNoParent(
105105
queryParams,
106106
globalCentroidDp,
107107
fieldInfo.getVectorSimilarityFunction(),
108-
new float[ES91Int4VectorsScorer.BULK_SIZE]
108+
new float[ES92Int7VectorsScorer.BULK_SIZE]
109109
);
110110
long offset = centroids.getFilePointer();
111111
return new CentroidIterator() {
@@ -128,7 +128,7 @@ private static CentroidIterator getCentroidIteratorWithParents(
128128
IndexInput centroids,
129129
int numParents,
130130
int numCentroids,
131-
ES91Int4VectorsScorer scorer,
131+
ES92Int7VectorsScorer scorer,
132132
byte[] quantizeQuery,
133133
OptimizedScalarQuantizer.QuantizationResult queryParams,
134134
float globalCentroidDp
@@ -140,7 +140,7 @@ private static CentroidIterator getCentroidIteratorWithParents(
140140
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
141141
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
142142
// score the parents
143-
final float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE];
143+
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
144144
score(
145145
parentsQueue,
146146
numParents,
@@ -152,7 +152,7 @@ private static CentroidIterator getCentroidIteratorWithParents(
152152
fieldInfo.getVectorSimilarityFunction(),
153153
scores
154154
);
155-
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
155+
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
156156
final long offset = centroids.getFilePointer();
157157
final long childrenOffset = offset + (long) Long.BYTES * numParents;
158158
// populate the children's queue by reading parents one by one
@@ -227,7 +227,7 @@ private static void populateOneChildrenGroup(
227227
long childrenOffset,
228228
long centroidQuantizeSize,
229229
FieldInfo fieldInfo,
230-
ES91Int4VectorsScorer scorer,
230+
ES92Int7VectorsScorer scorer,
231231
byte[] quantizeQuery,
232232
OptimizedScalarQuantizer.QuantizationResult queryParams,
233233
float globalCentroidDp,
@@ -254,16 +254,16 @@ private static void score(
254254
NeighborQueue neighborQueue,
255255
int size,
256256
int scoresOffset,
257-
ES91Int4VectorsScorer scorer,
257+
ES92Int7VectorsScorer scorer,
258258
byte[] quantizeQuery,
259259
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
260260
float centroidDp,
261261
VectorSimilarityFunction similarityFunction,
262262
float[] scores
263263
) throws IOException {
264-
int limit = size - ES91Int4VectorsScorer.BULK_SIZE + 1;
264+
int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1;
265265
int i = 0;
266-
for (; i < limit; i += ES91Int4VectorsScorer.BULK_SIZE) {
266+
for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) {
267267
scorer.scoreBulk(
268268
quantizeQuery,
269269
queryCorrections.lowerInterval(),
@@ -274,7 +274,7 @@ private static void score(
274274
centroidDp,
275275
scores
276276
);
277-
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
277+
for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
278278
neighborQueue.add(scoresOffset + i + j, scores[j]);
279279
}
280280
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
2727
import org.elasticsearch.logging.LogManager;
2828
import org.elasticsearch.logging.Logger;
29-
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
3029
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
30+
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
3131

3232
import java.io.IOException;
3333
import java.io.UncheckedIOException;
@@ -315,8 +315,8 @@ private void writeCentroidsWithParents(
315315
LongValues offsets,
316316
IndexOutput centroidOutput
317317
) throws IOException {
318-
DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter(
319-
ES91Int4VectorsScorer.BULK_SIZE,
318+
DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter(
319+
ES92Int7VectorsScorer.BULK_SIZE,
320320
centroidOutput
321321
);
322322
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
@@ -365,8 +365,8 @@ private void writeCentroidsWithoutParents(
365365
IndexOutput centroidOutput
366366
) throws IOException {
367367
centroidOutput.writeVInt(0);
368-
DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter(
369-
ES91Int4VectorsScorer.BULK_SIZE,
368+
DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter(
369+
ES92Int7VectorsScorer.BULK_SIZE,
370370
centroidOutput
371371
);
372372
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
@@ -571,7 +571,7 @@ public byte[] next() throws IOException {
571571
// It's possible that the vectors are on-heap and we cannot mutate them as we may quantize twice
572572
// due to overspill, so we copy the vector to a scratch array
573573
System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length);
574-
corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 4, centroid);
574+
corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 7, centroid);
575575
for (int i = 0; i < quantizedVectorScratch.length; i++) {
576576
quantizedVector[i] = (byte) quantizedVectorScratch[i];
577577
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,32 +29,6 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
2929

3030
abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;
3131

32-
private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
33-
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
34-
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
35-
}
36-
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
37-
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
38-
}
39-
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
40-
int targetComponentSum = correction.quantizedComponentSum();
41-
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
42-
out.writeShort((short) targetComponentSum);
43-
}
44-
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
45-
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
46-
}
47-
}
48-
49-
private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException {
50-
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
51-
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
52-
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
53-
int targetComponentSum = correction.quantizedComponentSum();
54-
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
55-
out.writeShort((short) targetComponentSum);
56-
}
57-
5832
static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
5933
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
6034

@@ -73,22 +47,48 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx
7347
corrections[j] = qvv.getCorrections();
7448
out.writeBytes(qv, qv.length);
7549
}
76-
writeCorrections(corrections, out);
50+
writeCorrections(corrections);
7751
}
7852
// write tail
7953
for (; i < qvv.count(); ++i) {
8054
byte[] qv = qvv.next();
8155
OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
8256
out.writeBytes(qv, qv.length);
83-
writeCorrection(correction, out);
57+
writeCorrection(correction);
58+
}
59+
}
60+
61+
private void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections) throws IOException {
62+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
63+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
64+
}
65+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
66+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
67+
}
68+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
69+
int targetComponentSum = correction.quantizedComponentSum();
70+
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
71+
out.writeShort((short) targetComponentSum);
72+
}
73+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
74+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
8475
}
8576
}
77+
78+
private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction) throws IOException {
79+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
80+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
81+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
82+
int targetComponentSum = correction.quantizedComponentSum();
83+
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
84+
out.writeShort((short) targetComponentSum);
85+
}
8686
}
8787

88-
static class FourBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
88+
static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
8989
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
9090

91-
FourBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
91+
SevenBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
9292
super(bulkSize, out);
9393
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
9494
}
@@ -103,15 +103,37 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx
103103
corrections[j] = qvv.getCorrections();
104104
out.writeBytes(qv, qv.length);
105105
}
106-
writeCorrections(corrections, out);
106+
writeCorrections(corrections);
107107
}
108108
// write tail
109109
for (; i < qvv.count(); ++i) {
110110
byte[] qv = qvv.next();
111111
OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
112112
out.writeBytes(qv, qv.length);
113-
writeCorrection(correction, out);
113+
writeCorrection(correction);
114114
}
115115
}
116+
117+
private void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections) throws IOException {
118+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
119+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
120+
}
121+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
122+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
123+
}
124+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
125+
out.writeInt(correction.quantizedComponentSum());
126+
}
127+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
128+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
129+
}
130+
}
131+
132+
private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction) throws IOException {
133+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
134+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
135+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
136+
out.writeInt(correction.quantizedComponentSum());
137+
}
116138
}
117139
}

0 commit comments

Comments
 (0)