Skip to content

Commit 8fc6cfa

Browse files
committed
Hierarchical centroid storage for DiskBBQ
1 parent 6932440 commit 8fc6cfa

File tree

11 files changed

+507
-74
lines changed

11 files changed

+507
-74
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ private static String formatIndexPath(CmdLineArgs args) {
8989
static Codec createCodec(CmdLineArgs args) {
9090
final KnnVectorsFormat format;
9191
if (args.indexType() == IndexType.IVF) {
92-
format = new IVFVectorsFormat(args.ivfClusterSize());
92+
format = new IVFVectorsFormat(args.ivfClusterSize(), IVFVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER);
9393
} else {
9494
if (args.quantizeBits() == 1) {
9595
if (args.indexType() == IndexType.FLAT) {

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 243 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -29,8 +29,6 @@
2929

3030
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
3131
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
32-
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
33-
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
3432
import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte;
3533
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
3634
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA;
@@ -41,7 +39,9 @@
4139
* brute force and then scores the top ones using the posting list.
4240
*/
4341
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
44-
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
42+
43+
// The percentage of centroids that are scored to keep recall
44+
public static final double CENTROID_SAMPLING_PERCENTAGE = 0.1;
4545

4646
public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
4747
super(state, rawVectorsReader);
@@ -54,8 +54,12 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
5454
final float globalCentroidDp = fieldEntry.globalCentroidDp();
5555
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
5656
final int[] scratch = new int[targetQuery.length];
57+
float[] targetQueryCpoy = ArrayUtil.copyArray(targetQuery);
58+
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
59+
VectorUtil.l2normalize(targetQueryCpoy);
60+
}
5761
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
58-
ArrayUtil.copyArray(targetQuery),
62+
targetQueryCpoy,
5963
scratch,
6064
(byte) 4,
6165
fieldEntry.globalCentroid()
@@ -65,68 +69,265 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
6569
quantized[i] = (byte) scratch[i];
6670
}
6771
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
68-
NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true);
6972
centroids.seek(0L);
73+
int numParents = centroids.readVInt();
74+
if (numParents > 0) {
75+
return getCentroidIteratorWithParents(
76+
fieldInfo,
77+
centroids,
78+
numParents,
79+
numCentroids,
80+
scorer,
81+
quantized,
82+
queryParams,
83+
globalCentroidDp
84+
);
85+
}
86+
return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp);
87+
}
88+
89+
/**
 * Builds a {@link CentroidIterator} for the case where the caller found no parent centroids.
 *
 * <p>Scores {@code numCentroids} centroids through {@code int4QuantizedScoreBulk} (ordinal offset 0)
 * into a {@code NeighborQueue}, remembers the file pointer of {@code centroids} once scoring is done,
 * and returns an iterator that pops ordinals from that queue. For each popped ordinal the iterator
 * reads one {@code long} located at {@code offset + Long.BYTES * ordinal}.
 *
 * @param fieldInfo        supplies the vector similarity function used for scoring
 * @param centroids        input holding the centroid data; read from its current position
 * @param numCentroids     number of centroids to score and capacity of the queue
 * @param scorer           int4 scorer passed through to the bulk scoring helper
 * @param quantizeQuery    quantized query bytes
 * @param queryParams      quantization result (corrections) for the query
 * @param globalCentroidDp value forwarded to the scorer as the centroid dot product
 * @return iterator over posting list offsets, in the order ordinals are popped from the queue
 * @throws IOException if reading from {@code centroids} fails
 */
private CentroidIterator getCentroidIteratorNoParent(
90+
FieldInfo fieldInfo,
91+
IndexInput centroids,
92+
int numCentroids,
93+
ES91Int4VectorsScorer scorer,
94+
byte[] quantizeQuery,
95+
OptimizedScalarQuantizer.QuantizationResult queryParams,
96+
float globalCentroidDp
97+
) throws IOException {
98+
// NOTE(review): the `true` flag presumably makes pop() return the best-scoring ordinal first - confirm against NeighborQueue.
final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
99+
int4QuantizedScoreBulk(
100+
neighborQueue,
101+
centroids,
102+
numCentroids,
103+
0,
104+
scorer,
105+
quantizeQuery,
106+
queryParams,
107+
new float[3], // targetCorrections
108+
globalCentroidDp,
109+
fieldInfo.getVectorSimilarityFunction(),
110+
new float[ES91Int4VectorsScorer.BULK_SIZE]
111+
);
112+
// Position right after the scored centroid data; the iterator below indexes one long per ordinal from here.
long offset = centroids.getFilePointer();
113+
return new CentroidIterator() {
114+
@Override
115+
public boolean hasNext() {
116+
return neighborQueue.size() > 0;
117+
}
118+
119+
@Override
120+
public long nextPostingListOffset() throws IOException {
121+
int centroidOrdinal = neighborQueue.pop();
122+
// The (long) cast keeps the byte-offset multiplication in 64-bit arithmetic.
centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
123+
return centroids.readLong();
124+
}
125+
};
126+
}
127+
128+
private CentroidIterator getCentroidIteratorWithParents(
129+
FieldInfo fieldInfo,
130+
IndexInput centroids,
131+
int numParents,
132+
int numCentroids,
133+
ES91Int4VectorsScorer scorer,
134+
byte[] quantizeQuery,
135+
OptimizedScalarQuantizer.QuantizationResult queryParams,
136+
float globalCentroidDp
137+
) throws IOException {
138+
final int maxChildrenSize = centroids.readVInt();
139+
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
140+
final float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE];
70141
final float[] centroidCorrectiveValues = new float[3];
71-
for (int i = 0; i < numCentroids; i++) {
72-
final float qcDist = scorer.int4DotProduct(quantized);
73-
centroids.readFloats(centroidCorrectiveValues, 0, 3);
74-
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
75-
float score = int4QuantizedScore(
76-
qcDist,
142+
int4QuantizedScoreBulk(
143+
parentsQueue,
144+
centroids,
145+
numParents,
146+
0,
147+
scorer,
148+
quantizeQuery,
149+
queryParams,
150+
centroidCorrectiveValues, // targetCorrections
151+
globalCentroidDp,
152+
fieldInfo.getVectorSimilarityFunction(),
153+
scores
154+
);
155+
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
156+
long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
157+
long offset = centroids.getFilePointer();
158+
long childrenOffset = offset + (long) Long.BYTES * numParents;
159+
NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
160+
NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
161+
while (parentsQueue.size() > 0 && neighborQueue.size() < bufferSize) {
162+
int pop = parentsQueue.pop();
163+
populateOneChildrenGroup(
164+
currentParentQueue,
165+
centroids,
166+
offset + 2L * Integer.BYTES * pop,
167+
childrenOffset,
168+
centroidQuantizeSize,
169+
fieldInfo,
170+
scorer,
171+
quantizeQuery,
77172
queryParams,
78-
fieldInfo.getVectorDimension(),
79173
centroidCorrectiveValues,
80-
quantizedCentroidComponentSum,
81174
globalCentroidDp,
82-
fieldInfo.getVectorSimilarityFunction()
175+
scores
83176
);
84-
queue.add(i, score);
177+
while (currentParentQueue.size() > 0 && neighborQueue.size() < bufferSize) {
178+
float score = currentParentQueue.topScore();
179+
int children = currentParentQueue.pop();
180+
neighborQueue.add(children, score);
181+
}
85182
}
86-
final long offset = centroids.getFilePointer();
183+
long childrenFileOffsets = childrenOffset + centroidQuantizeSize * numCentroids;
184+
87185
return new CentroidIterator() {
88186
@Override
89187
public boolean hasNext() {
90-
return queue.size() > 0;
188+
return neighborQueue.size() > 0;
91189
}
92190

93191
@Override
94192
public long nextPostingListOffset() throws IOException {
95-
int centroidOrdinal = queue.pop();
96-
centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
193+
int centroidOrdinal = neighborQueue.pop();
194+
updateQueue();
195+
centroids.seek(childrenFileOffsets + (long) Long.BYTES * centroidOrdinal);
97196
return centroids.readLong();
98197
}
198+
199+
private void updateQueue() throws IOException {
200+
if (currentParentQueue.size() > 0) {
201+
float score = currentParentQueue.topScore();
202+
int children = currentParentQueue.pop();
203+
neighborQueue.add(children, score);
204+
} else {
205+
if (parentsQueue.size() > 0) {
206+
int pop = parentsQueue.pop();
207+
populateOneChildrenGroup(
208+
currentParentQueue,
209+
centroids,
210+
offset + 2L * Integer.BYTES * pop,
211+
childrenOffset,
212+
centroidQuantizeSize,
213+
fieldInfo,
214+
scorer,
215+
quantizeQuery,
216+
queryParams,
217+
centroidCorrectiveValues,
218+
globalCentroidDp,
219+
scores
220+
);
221+
updateQueue();
222+
}
223+
}
224+
}
99225
};
100226
}
101227

102-
// TODO can we do this in off-heap blocks?
103-
private float int4QuantizedScore(
104-
float qcDist,
228+
/**
 * Scores the children of a single parent entry and adds them to {@code neighborQueue}.
 *
 * <p>Seeks to {@code parentOffset}, reads two ints (the ordinal of the first child and the number of
 * children), then seeks to {@code childrenOffset + centroidQuantizeSize * childrenOrdinal} and scores
 * {@code numChildren} entries through {@code int4QuantizedScoreBulk}. The first child ordinal is passed
 * as the ordinal offset, so queue entries are numbered from {@code childrenOrdinal} upwards.
 *
 * @param neighborQueue        queue receiving the scored child ordinals
 * @param centroids            input to read from; its position is changed by this call
 * @param parentOffset         absolute position of the parent's (first child ordinal, child count) pair
 * @param childrenOffset       absolute position where the child centroid entries start
 * @param centroidQuantizeSize size in bytes of one child centroid entry
 * @param fieldInfo            supplies the vector similarity function
 * @param scorer               int4 scorer passed through to the bulk scoring helper
 * @param quantizeQuery        quantized query bytes
 * @param queryParams          quantization result (corrections) for the query
 * @param targetCorrections    scratch array forwarded to the bulk scoring helper
 * @param globalCentroidDp     value forwarded to the scorer as the centroid dot product
 * @param scores               scratch array forwarded to the bulk scoring helper
 * @throws IOException if seeking or reading {@code centroids} fails
 */
private void populateOneChildrenGroup(
229+
NeighborQueue neighborQueue,
230+
IndexInput centroids,
231+
long parentOffset,
232+
long childrenOffset,
233+
long centroidQuantizeSize,
234+
FieldInfo fieldInfo,
235+
ES91Int4VectorsScorer scorer,
236+
byte[] quantizeQuery,
237+
OptimizedScalarQuantizer.QuantizationResult queryParams,
238+
float[] targetCorrections,
239+
float globalCentroidDp,
240+
float[] scores
241+
) throws IOException {
242+
centroids.seek(parentOffset);
243+
// Parent entry layout as read here: first child ordinal, then child count.
int childrenOrdinal = centroids.readInt();
244+
int numChildren = centroids.readInt();
245+
// centroidQuantizeSize is a long, so this offset is computed in 64-bit arithmetic.
centroids.seek(childrenOffset + centroidQuantizeSize * childrenOrdinal);
246+
int4QuantizedScoreBulk(
247+
neighborQueue,
248+
centroids,
249+
numChildren,
250+
childrenOrdinal,
251+
scorer,
252+
quantizeQuery,
253+
queryParams,
254+
targetCorrections,
255+
globalCentroidDp,
256+
fieldInfo.getVectorSimilarityFunction(),
257+
scores
258+
);
259+
}
260+
261+
private void int4QuantizedScoreBulk(
262+
NeighborQueue neighborQueue,
263+
IndexInput centroids,
264+
int size,
265+
int scoresOffset,
266+
ES91Int4VectorsScorer scorer,
267+
byte[] quantizeQuery,
105268
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
106-
int dims,
107269
float[] targetCorrections,
108-
int targetComponentSum,
109270
float centroidDp,
110-
VectorSimilarityFunction similarityFunction
111-
) {
112-
float ax = targetCorrections[0];
113-
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
114-
float ay = queryCorrections.lowerInterval();
115-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
116-
float y1 = queryCorrections.quantizedComponentSum();
117-
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
118-
if (similarityFunction == EUCLIDEAN) {
119-
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
120-
return Math.max(1 / (1f + score), 0);
121-
} else {
122-
// For cosine and max inner product, we need to apply the additional correction, which is
123-
// assumed to be the non-centered dot-product between the vector and the centroid
124-
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
125-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
126-
return VectorUtil.scaleMaxInnerProductScore(score);
271+
VectorSimilarityFunction similarityFunction,
272+
float[] scores
273+
) throws IOException {
274+
int limit = size - ES91Int4VectorsScorer.BULK_SIZE + 1;
275+
int i = 0;
276+
for (; i < limit; i += ES91Int4VectorsScorer.BULK_SIZE) {
277+
scorer.scoreBulk(
278+
quantizeQuery,
279+
queryCorrections.lowerInterval(),
280+
queryCorrections.upperInterval(),
281+
queryCorrections.quantizedComponentSum(),
282+
queryCorrections.additionalCorrection(),
283+
similarityFunction,
284+
centroidDp,
285+
scores
286+
);
287+
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
288+
neighborQueue.add(scoresOffset + i + j, scores[j]);
127289
}
128-
return Math.max((1f + score) / 2f, 0);
129290
}
291+
292+
for (; i < size; i++) {
293+
float score = int4QuantizedScore(
294+
centroids,
295+
scorer,
296+
quantizeQuery,
297+
queryCorrections,
298+
targetCorrections,
299+
centroidDp,
300+
similarityFunction
301+
);
302+
neighborQueue.add(scoresOffset + i, score);
303+
}
304+
}
305+
306+
/**
 * Scores one quantized centroid against the quantized query.
 *
 * <p>Calls {@code scorer.int4DotProduct(quantizeQuery)}, then reads three floats into
 * {@code targetCorrections} and one short (interpreted as unsigned) from {@code centroids}, and
 * returns the result of {@code scorer.applyCorrections(...)} on those values.
 *
 * @param centroids          input from which the three corrective floats and the component sum are read
 * @param scorer             int4 scorer providing the raw dot product and the correction step
 * @param quantizeQuery      quantized query bytes
 * @param queryCorrections   quantization result (intervals, component sum, additional correction) of the query
 * @param targetCorrections  scratch array of at least 3 floats; overwritten on every call
 * @param centroidDp         value forwarded to {@code applyCorrections} as the centroid dot product
 * @param similarityFunction similarity function forwarded to {@code applyCorrections}
 * @return the corrected score returned by {@code scorer.applyCorrections}
 * @throws IOException if reading fails
 */
private float int4QuantizedScore(
307+
IndexInput centroids,
308+
ES91Int4VectorsScorer scorer,
309+
byte[] quantizeQuery,
310+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
311+
float[] targetCorrections,
312+
float centroidDp,
313+
VectorSimilarityFunction similarityFunction
314+
) throws IOException {
315+
// NOTE(review): presumably int4DotProduct consumes the quantized vector bytes from the scorer's own input,
// leaving `centroids` positioned at the corrective values - confirm against ES91Int4VectorsScorer.
float qcDist = scorer.int4DotProduct(quantizeQuery);
316+
centroids.readFloats(targetCorrections, 0, 3);
317+
// Stored as a 16-bit value; widen without sign extension.
final int targetComponentSum = Short.toUnsignedInt(centroids.readShort());
318+
return scorer.applyCorrections(
319+
queryCorrections.lowerInterval(),
320+
queryCorrections.upperInterval(),
321+
queryCorrections.quantizedComponentSum(),
322+
queryCorrections.additionalCorrection(),
323+
similarityFunction,
324+
centroidDp,
325+
targetCorrections[0],
326+
targetCorrections[1],
327+
targetComponentSum,
328+
targetCorrections[2],
329+
qcDist
330+
);
130331
}
131332

132333
@Override

0 commit comments

Comments
 (0)