Skip to content

Commit 63b77a6

Browse files
authored
Hierarchical centroid storage for DiskBBQ (#132010)
This commit presents a hierarchical layer on top of the DiskBBQ centroids to reduce the number of centroids scored at search time.
1 parent 713d874 commit 63b77a6

File tree

11 files changed

+473
-77
lines changed

11 files changed

+473
-77
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ private static String formatIndexPath(CmdLineArgs args) {
101101
static Codec createCodec(CmdLineArgs args) {
102102
final KnnVectorsFormat format;
103103
if (args.indexType() == IndexType.IVF) {
104-
format = new IVFVectorsFormat(args.ivfClusterSize());
104+
format = new IVFVectorsFormat(args.ivfClusterSize(), IVFVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER);
105105
} else {
106106
if (args.quantizeBits() == 1) {
107107
if (args.indexType() == IndexType.FLAT) {

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 209 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929

3030
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
3131
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
32-
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
33-
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
3432
import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte;
3533
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
3634
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA;
@@ -41,7 +39,9 @@
4139
* brute force and then scores the top ones using the posting list.
4240
*/
4341
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
44-
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
42+
43+
// The percentage of centroids that are scored to keep recall
44+
public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2;
4545

4646
public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
4747
super(state, rawVectorsReader);
@@ -54,8 +54,12 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
5454
final float globalCentroidDp = fieldEntry.globalCentroidDp();
5555
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
5656
final int[] scratch = new int[targetQuery.length];
57+
float[] targetQueryCopy = ArrayUtil.copyArray(targetQuery);
58+
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
59+
VectorUtil.l2normalize(targetQueryCopy);
60+
}
5761
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
58-
ArrayUtil.copyArray(targetQuery),
62+
targetQueryCopy,
5963
scratch,
6064
(byte) 4,
6165
fieldEntry.globalCentroid()
@@ -65,67 +69,227 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
6569
quantized[i] = (byte) scratch[i];
6670
}
6771
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
68-
NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true);
6972
centroids.seek(0L);
70-
final float[] centroidCorrectiveValues = new float[3];
71-
for (int i = 0; i < numCentroids; i++) {
72-
final float qcDist = scorer.int4DotProduct(quantized);
73-
centroids.readFloats(centroidCorrectiveValues, 0, 3);
74-
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
75-
float score = int4QuantizedScore(
76-
qcDist,
73+
int numParents = centroids.readVInt();
74+
if (numParents > 0) {
75+
return getCentroidIteratorWithParents(
76+
fieldInfo,
77+
centroids,
78+
numParents,
79+
numCentroids,
80+
scorer,
81+
quantized,
7782
queryParams,
78-
fieldInfo.getVectorDimension(),
79-
centroidCorrectiveValues,
80-
quantizedCentroidComponentSum,
81-
globalCentroidDp,
82-
fieldInfo.getVectorSimilarityFunction()
83+
globalCentroidDp
8384
);
84-
queue.add(i, score);
8585
}
86-
final long offset = centroids.getFilePointer();
86+
return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp);
87+
}
88+
89+
private static CentroidIterator getCentroidIteratorNoParent(
90+
FieldInfo fieldInfo,
91+
IndexInput centroids,
92+
int numCentroids,
93+
ES91Int4VectorsScorer scorer,
94+
byte[] quantizeQuery,
95+
OptimizedScalarQuantizer.QuantizationResult queryParams,
96+
float globalCentroidDp
97+
) throws IOException {
98+
final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
99+
score(
100+
neighborQueue,
101+
numCentroids,
102+
0,
103+
scorer,
104+
quantizeQuery,
105+
queryParams,
106+
globalCentroidDp,
107+
fieldInfo.getVectorSimilarityFunction(),
108+
new float[ES91Int4VectorsScorer.BULK_SIZE]
109+
);
110+
long offset = centroids.getFilePointer();
87111
return new CentroidIterator() {
88112
@Override
89113
public boolean hasNext() {
90-
return queue.size() > 0;
114+
return neighborQueue.size() > 0;
91115
}
92116

93117
@Override
94118
public long nextPostingListOffset() throws IOException {
95-
int centroidOrdinal = queue.pop();
119+
int centroidOrdinal = neighborQueue.pop();
96120
centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
97121
return centroids.readLong();
98122
}
99123
};
100124
}
101125

102-
// TODO can we do this in off-heap blocks?
103-
private float int4QuantizedScore(
104-
float qcDist,
126+
private static CentroidIterator getCentroidIteratorWithParents(
127+
FieldInfo fieldInfo,
128+
IndexInput centroids,
129+
int numParents,
130+
int numCentroids,
131+
ES91Int4VectorsScorer scorer,
132+
byte[] quantizeQuery,
133+
OptimizedScalarQuantizer.QuantizationResult queryParams,
134+
float globalCentroidDp
135+
) throws IOException {
136+
// build the three queues we are going to use
137+
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
138+
final int maxChildrenSize = centroids.readVInt();
139+
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
140+
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
141+
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
142+
// score the parents
143+
final float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE];
144+
score(
145+
parentsQueue,
146+
numParents,
147+
0,
148+
scorer,
149+
quantizeQuery,
150+
queryParams,
151+
globalCentroidDp,
152+
fieldInfo.getVectorSimilarityFunction(),
153+
scores
154+
);
155+
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
156+
final long offset = centroids.getFilePointer();
157+
final long childrenOffset = offset + (long) Long.BYTES * numParents;
158+
// populate the children's queue by reading parents one by one
159+
while (parentsQueue.size() > 0 && neighborQueue.size() < bufferSize) {
160+
final int pop = parentsQueue.pop();
161+
populateOneChildrenGroup(
162+
currentParentQueue,
163+
centroids,
164+
offset + 2L * Integer.BYTES * pop,
165+
childrenOffset,
166+
centroidQuantizeSize,
167+
fieldInfo,
168+
scorer,
169+
quantizeQuery,
170+
queryParams,
171+
globalCentroidDp,
172+
scores
173+
);
174+
while (currentParentQueue.size() > 0 && neighborQueue.size() < bufferSize) {
175+
final float score = currentParentQueue.topScore();
176+
final int children = currentParentQueue.pop();
177+
neighborQueue.add(children, score);
178+
}
179+
}
180+
final long childrenFileOffsets = childrenOffset + centroidQuantizeSize * numCentroids;
181+
return new CentroidIterator() {
182+
@Override
183+
public boolean hasNext() {
184+
return neighborQueue.size() > 0;
185+
}
186+
187+
@Override
188+
public long nextPostingListOffset() throws IOException {
189+
int centroidOrdinal = neighborQueue.pop();
190+
updateQueue(); // add one children if available so the queue remains fully populated
191+
centroids.seek(childrenFileOffsets + (long) Long.BYTES * centroidOrdinal);
192+
return centroids.readLong();
193+
}
194+
195+
private void updateQueue() throws IOException {
196+
if (currentParentQueue.size() > 0) {
197+
// add a children from the current parent queue
198+
float score = currentParentQueue.topScore();
199+
int children = currentParentQueue.pop();
200+
neighborQueue.add(children, score);
201+
} else if (parentsQueue.size() > 0) {
202+
// add a new parent from the current parent queue
203+
int pop = parentsQueue.pop();
204+
populateOneChildrenGroup(
205+
currentParentQueue,
206+
centroids,
207+
offset + 2L * Integer.BYTES * pop,
208+
childrenOffset,
209+
centroidQuantizeSize,
210+
fieldInfo,
211+
scorer,
212+
quantizeQuery,
213+
queryParams,
214+
globalCentroidDp,
215+
scores
216+
);
217+
updateQueue();
218+
}
219+
}
220+
};
221+
}
222+
223+
private static void populateOneChildrenGroup(
224+
NeighborQueue neighborQueue,
225+
IndexInput centroids,
226+
long parentOffset,
227+
long childrenOffset,
228+
long centroidQuantizeSize,
229+
FieldInfo fieldInfo,
230+
ES91Int4VectorsScorer scorer,
231+
byte[] quantizeQuery,
232+
OptimizedScalarQuantizer.QuantizationResult queryParams,
233+
float globalCentroidDp,
234+
float[] scores
235+
) throws IOException {
236+
centroids.seek(parentOffset);
237+
int childrenOrdinal = centroids.readInt();
238+
int numChildren = centroids.readInt();
239+
centroids.seek(childrenOffset + centroidQuantizeSize * childrenOrdinal);
240+
score(
241+
neighborQueue,
242+
numChildren,
243+
childrenOrdinal,
244+
scorer,
245+
quantizeQuery,
246+
queryParams,
247+
globalCentroidDp,
248+
fieldInfo.getVectorSimilarityFunction(),
249+
scores
250+
);
251+
}
252+
253+
private static void score(
254+
NeighborQueue neighborQueue,
255+
int size,
256+
int scoresOffset,
257+
ES91Int4VectorsScorer scorer,
258+
byte[] quantizeQuery,
105259
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
106-
int dims,
107-
float[] targetCorrections,
108-
int targetComponentSum,
109260
float centroidDp,
110-
VectorSimilarityFunction similarityFunction
111-
) {
112-
float ax = targetCorrections[0];
113-
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
114-
float ay = queryCorrections.lowerInterval();
115-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
116-
float y1 = queryCorrections.quantizedComponentSum();
117-
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
118-
if (similarityFunction == EUCLIDEAN) {
119-
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
120-
return Math.max(1 / (1f + score), 0);
121-
} else {
122-
// For cosine and max inner product, we need to apply the additional correction, which is
123-
// assumed to be the non-centered dot-product between the vector and the centroid
124-
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
125-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
126-
return VectorUtil.scaleMaxInnerProductScore(score);
261+
VectorSimilarityFunction similarityFunction,
262+
float[] scores
263+
) throws IOException {
264+
int limit = size - ES91Int4VectorsScorer.BULK_SIZE + 1;
265+
int i = 0;
266+
for (; i < limit; i += ES91Int4VectorsScorer.BULK_SIZE) {
267+
scorer.scoreBulk(
268+
quantizeQuery,
269+
queryCorrections.lowerInterval(),
270+
queryCorrections.upperInterval(),
271+
queryCorrections.quantizedComponentSum(),
272+
queryCorrections.additionalCorrection(),
273+
similarityFunction,
274+
centroidDp,
275+
scores
276+
);
277+
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
278+
neighborQueue.add(scoresOffset + i + j, scores[j]);
127279
}
128-
return Math.max((1f + score) / 2f, 0);
280+
}
281+
282+
for (; i < size; i++) {
283+
float score = scorer.score(
284+
quantizeQuery,
285+
queryCorrections.lowerInterval(),
286+
queryCorrections.upperInterval(),
287+
queryCorrections.quantizedComponentSum(),
288+
queryCorrections.additionalCorrection(),
289+
similarityFunction,
290+
centroidDp
291+
);
292+
neighborQueue.add(scoresOffset + i, score);
129293
}
130294
}
131295

0 commit comments

Comments
 (0)