Skip to content

Commit a66087f

Browse files
committed
[DiskBBQ] Break big posting lists into blocks
1 parent f91cc68 commit a66087f

File tree

3 files changed

+239
-31
lines changed

3 files changed

+239
-31
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,10 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
321321
final float[] correctionsAdd = new float[BULK_SIZE];
322322
final int[] docIdsScratch;
323323

324-
int vectors;
324+
int totalVectors;
325325
boolean quantized = false;
326326
float centroidDp;
327327
final float[] centroid;
328-
long slicePos;
329328
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
330329
DocIdsWriter docIdsWriter = new DocIdsWriter();
331330

@@ -367,12 +366,9 @@ public int resetPostingsScorer(long offset) throws IOException {
367366
indexInput.seek(offset);
368367
indexInput.readFloats(centroid, 0, centroid.length);
369368
centroidDp = Float.intBitsToFloat(indexInput.readInt());
370-
vectors = indexInput.readVInt();
371-
// read the doc ids
372-
assert vectors <= docIdsScratch.length;
373-
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
374-
slicePos = indexInput.getFilePointer();
375-
return vectors;
369+
totalVectors = indexInput.readVInt();
370+
371+
return totalVectors;
376372
}
377373

378374
float scoreIndividually(int offset) throws IOException {
@@ -381,13 +377,13 @@ float scoreIndividually(int offset) throws IOException {
381377
for (int j = 0; j < BULK_SIZE; j++) {
382378
int doc = docIdsScratch[j + offset];
383379
if (doc != -1) {
384-
indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
385380
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
386381
scores[j] = qcDist;
382+
} else {
383+
indexInput.skipBytes(quantizedVectorByteSize);
387384
}
388385
}
389386
// read in all corrections
390-
indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize));
391387
indexInput.readFloats(correctionsLower, 0, BULK_SIZE);
392388
indexInput.readFloats(correctionsUpper, 0, BULK_SIZE);
393389
for (int j = 0; j < BULK_SIZE; j++) {
@@ -444,18 +440,36 @@ private static int collect(int[] docIds, int offset, KnnCollector knnCollector,
444440

445441
@Override
446442
public int visit(KnnCollector knnCollector) throws IOException {
443+
byte postingListType = indexInput.readByte();
444+
if (postingListType == DefaultIVFVectorsWriter.SINGLE_BLOCK_POSTING_LIST) {
445+
return singleBlockVisit(knnCollector, totalVectors);
446+
} else {
447+
assert postingListType == DefaultIVFVectorsWriter.MULTI_BLOCK_POSTING_LIST;
448+
final int numBlocks = indexInput.readVInt();
449+
int scoredDocs = 0;
450+
for (int i = 0; i < numBlocks; i++) {
451+
final int numVectors = indexInput.readVInt();
452+
scoredDocs += singleBlockVisit(knnCollector, numVectors);
453+
}
454+
return scoredDocs;
455+
}
456+
}
457+
458+
private int singleBlockVisit(KnnCollector knnCollector, int numVectors) throws IOException {
459+
assert numVectors <= docIdsScratch.length : "numVectors: " + numVectors + ", docIdsScratch.length: " + docIdsScratch.length;
460+
docIdsWriter.readInts(indexInput, numVectors, docIdsScratch);
447461
// block processing
448462
int scoredDocs = 0;
449-
int limit = vectors - BULK_SIZE + 1;
463+
int limit = numVectors - BULK_SIZE + 1;
450464
int i = 0;
451465
for (; i < limit; i += BULK_SIZE) {
452466
int docsToScore = BULK_SIZE - filterDocs(docIdsScratch, i, needsScoring);
453467
if (docsToScore == 0) {
468+
indexInput.skipBytes(BULK_SIZE * quantizedByteLength);
454469
continue;
455470
}
456471
quantizeQueryIfNecessary();
457-
indexInput.seek(slicePos + i * quantizedByteLength);
458-
float maxScore = Float.NEGATIVE_INFINITY;
472+
float maxScore;
459473
if (docsToScore < BULK_SIZE / 2) {
460474
maxScore = scoreIndividually(i);
461475
} else {
@@ -475,11 +489,10 @@ public int visit(KnnCollector knnCollector) throws IOException {
475489
}
476490
}
477491
// process tail
478-
for (; i < vectors; i++) {
492+
for (; i < numVectors; i++) {
479493
int doc = docIdsScratch[i];
480494
if (needsScoring.test(doc)) {
481495
quantizeQueryIfNecessary();
482-
indexInput.seek(slicePos + i * quantizedByteLength);
483496
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
484497
indexInput.readFloats(correctiveValues, 0, 3);
485498
final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
@@ -498,6 +511,8 @@ public int visit(KnnCollector knnCollector) throws IOException {
498511
);
499512
scoredDocs++;
500513
knnCollector.collect(doc, score);
514+
} else {
515+
indexInput.skipBytes(quantizedByteLength);
501516
}
502517
}
503518
if (scoredDocs > 0) {

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 183 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.nio.ByteOrder;
3636
import java.util.AbstractList;
3737
import java.util.Arrays;
38+
import java.util.function.IntPredicate;
3839

3940
/**
4041
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
@@ -43,6 +44,10 @@
4344
*/
4445
public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
4546
private static final Logger logger = LogManager.getLogger(DefaultIVFVectorsWriter.class);
47+
// posting lists bigger than that will be split in two or more blocks
48+
private static final int MAX_POSTING_LIST_BLOCK_SIZE = 16 * 100;
49+
public static final byte SINGLE_BLOCK_POSTING_LIST = 0;
50+
public static final byte MULTI_BLOCK_POSTING_LIST = 1;
4651

4752
private final int vectorPerCluster;
4853
private final int centroidsPerParentCluster;
@@ -98,7 +103,7 @@ LongValues buildAndWritePostingsLists(
98103
}
99104
}
100105
// write the max posting list size
101-
postingsOutput.writeVInt(maxPostingListSize);
106+
postingsOutput.writeVInt(Math.min(MAX_POSTING_LIST_BLOCK_SIZE, maxPostingListSize));
102107
// write the posting lists
103108
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
104109
DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -121,13 +126,31 @@ LongValues buildAndWritePostingsLists(
121126
int size = cluster.length;
122127
// write docIds
123128
postingsOutput.writeVInt(size);
124-
onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
125-
// TODO we might want to consider putting the docIds in a separate file
126-
// to aid with only having to fetch vectors from slower storage when they are required
127-
// keeping them in the same file indicates we pull the entire file into cache
128-
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
129-
// write vectors
130-
bulkWriter.writeVectors(onHeapQuantizedVectors);
129+
if (size > MAX_POSTING_LIST_BLOCK_SIZE) {
130+
postingsOutput.writeByte(MULTI_BLOCK_POSTING_LIST);
131+
writeOnHeapMultiBlockPostingList(
132+
postingsOutput,
133+
floatVectorValues,
134+
onHeapQuantizedVectors,
135+
centroid,
136+
cluster,
137+
size,
138+
docIdsWriter,
139+
bulkWriter
140+
);
141+
} else {
142+
postingsOutput.writeByte(SINGLE_BLOCK_POSTING_LIST);
143+
writeOnHeapSingleBlockPostingList(
144+
postingsOutput,
145+
floatVectorValues,
146+
onHeapQuantizedVectors,
147+
centroid,
148+
k -> cluster[k],
149+
size,
150+
docIdsWriter,
151+
bulkWriter
152+
);
153+
}
131154
}
132155

133156
if (logger.isDebugEnabled()) {
@@ -137,6 +160,69 @@ LongValues buildAndWritePostingsLists(
137160
return offsets.build();
138161
}
139162

163+
private void writeOnHeapMultiBlockPostingList(
164+
IndexOutput postingsOutput,
165+
FloatVectorValues floatVectorValues,
166+
OnHeapQuantizedVectors onHeapQuantizedVectors,
167+
float[] centroid,
168+
int[] cluster,
169+
int size,
170+
DocIdsWriter docIdsWriter,
171+
DiskBBQBulkWriter bulkWriter
172+
) throws IOException {
173+
int numBlocks = (int) Math.ceil((double) size / MAX_POSTING_LIST_BLOCK_SIZE);
174+
postingsOutput.writeVInt(numBlocks);
175+
for (int i = 0; i < numBlocks - 1; i++) {
176+
int offset = MAX_POSTING_LIST_BLOCK_SIZE * i;
177+
postingsOutput.writeVInt(MAX_POSTING_LIST_BLOCK_SIZE);
178+
writeOnHeapSingleBlockPostingList(
179+
postingsOutput,
180+
floatVectorValues,
181+
onHeapQuantizedVectors,
182+
centroid,
183+
k -> cluster[offset + k],
184+
MAX_POSTING_LIST_BLOCK_SIZE,
185+
docIdsWriter,
186+
bulkWriter
187+
);
188+
}
189+
int lastBlock = size - (numBlocks - 1) * MAX_POSTING_LIST_BLOCK_SIZE;
190+
assert lastBlock >= 0;
191+
if (lastBlock > 0) {
192+
postingsOutput.writeVInt(lastBlock);
193+
writeOnHeapSingleBlockPostingList(
194+
postingsOutput,
195+
floatVectorValues,
196+
onHeapQuantizedVectors,
197+
centroid,
198+
k -> cluster[(numBlocks - 1) * MAX_POSTING_LIST_BLOCK_SIZE + k],
199+
lastBlock,
200+
docIdsWriter,
201+
bulkWriter
202+
);
203+
}
204+
}
205+
206+
private void writeOnHeapSingleBlockPostingList(
207+
IndexOutput postingsOutput,
208+
FloatVectorValues floatVectorValues,
209+
OnHeapQuantizedVectors onHeapQuantizedVectors,
210+
float[] centroid,
211+
IntToIntFunction cluster,
212+
int size,
213+
DocIdsWriter docIdsWriter,
214+
DiskBBQBulkWriter bulkWriter
215+
) throws IOException {
216+
217+
onHeapQuantizedVectors.reset(centroid, size, cluster);
218+
// TODO we might want to consider putting the docIds in a separate file
219+
// to aid with only having to fetch vectors from slower storage when they are required
220+
// keeping them in the same file indicates we pull the entire file into cache
221+
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.apply(j)), size, postingsOutput);
222+
// write vectors
223+
bulkWriter.writeVectors(onHeapQuantizedVectors);
224+
}
225+
140226
@Override
141227
LongValues buildAndWritePostingsLists(
142228
FieldInfo fieldInfo,
@@ -237,7 +323,7 @@ LongValues buildAndWritePostingsLists(
237323
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
238324
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
239325
// write the max posting list size
240-
postingsOutput.writeVInt(maxPostingListSize);
326+
postingsOutput.writeVInt(Math.min(MAX_POSTING_LIST_BLOCK_SIZE, maxPostingListSize));
241327
// write the posting lists
242328
for (int c = 0; c < centroidSupplier.size(); c++) {
243329
float[] centroid = centroidSupplier.centroid(c);
@@ -252,13 +338,31 @@ LongValues buildAndWritePostingsLists(
252338
// write docIds
253339
int size = cluster.length;
254340
postingsOutput.writeVInt(size);
255-
offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
256-
// TODO we might want to consider putting the docIds in a separate file
257-
// to aid with only having to fetch vectors from slower storage when they are required
258-
// keeping them in the same file indicates we pull the entire file into cache
259-
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
260-
// write vectors
261-
bulkWriter.writeVectors(offHeapQuantizedVectors);
341+
if (size > MAX_POSTING_LIST_BLOCK_SIZE) {
342+
postingsOutput.writeByte(MULTI_BLOCK_POSTING_LIST);
343+
writeOffHeapMultiBlockPostingList(
344+
postingsOutput,
345+
floatVectorValues,
346+
offHeapQuantizedVectors,
347+
cluster,
348+
size,
349+
isOverspill,
350+
docIdsWriter,
351+
bulkWriter
352+
);
353+
} else {
354+
postingsOutput.writeByte(SINGLE_BLOCK_POSTING_LIST);
355+
writeOffHeapBlockPostingList(
356+
postingsOutput,
357+
floatVectorValues,
358+
offHeapQuantizedVectors,
359+
k -> cluster[k],
360+
size,
361+
b -> isOverspill[b],
362+
docIdsWriter,
363+
bulkWriter
364+
);
365+
}
262366
}
263367

264368
if (logger.isDebugEnabled()) {
@@ -268,6 +372,69 @@ LongValues buildAndWritePostingsLists(
268372
}
269373
}
270374

375+
private void writeOffHeapMultiBlockPostingList(
376+
IndexOutput postingsOutput,
377+
FloatVectorValues floatVectorValues,
378+
OffHeapQuantizedVectors offHeapQuantizedVectors,
379+
int[] cluster,
380+
int size,
381+
boolean[] isOverspill,
382+
DocIdsWriter docIdsWriter,
383+
DiskBBQBulkWriter bulkWriter
384+
) throws IOException {
385+
int numBlocks = (int) Math.ceil((double) size / MAX_POSTING_LIST_BLOCK_SIZE);
386+
postingsOutput.writeVInt(numBlocks);
387+
for (int i = 0; i < numBlocks - 1; i++) {
388+
int offset = MAX_POSTING_LIST_BLOCK_SIZE * i;
389+
postingsOutput.writeVInt(MAX_POSTING_LIST_BLOCK_SIZE);
390+
writeOffHeapBlockPostingList(
391+
postingsOutput,
392+
floatVectorValues,
393+
offHeapQuantizedVectors,
394+
k -> cluster[offset + k],
395+
MAX_POSTING_LIST_BLOCK_SIZE,
396+
b -> isOverspill[offset + b],
397+
docIdsWriter,
398+
bulkWriter
399+
);
400+
}
401+
int lastBlock = size - (numBlocks - 1) * MAX_POSTING_LIST_BLOCK_SIZE;
402+
assert lastBlock >= 0;
403+
if (lastBlock > 0) {
404+
postingsOutput.writeVInt(lastBlock);
405+
writeOffHeapBlockPostingList(
406+
postingsOutput,
407+
floatVectorValues,
408+
offHeapQuantizedVectors,
409+
k -> cluster[(numBlocks - 1) * MAX_POSTING_LIST_BLOCK_SIZE + k],
410+
lastBlock,
411+
b -> isOverspill[(numBlocks - 1) * MAX_POSTING_LIST_BLOCK_SIZE + b],
412+
docIdsWriter,
413+
bulkWriter
414+
);
415+
}
416+
}
417+
418+
private void writeOffHeapBlockPostingList(
419+
IndexOutput postingsOutput,
420+
FloatVectorValues floatVectorValues,
421+
OffHeapQuantizedVectors offHeapQuantizedVectors,
422+
IntToIntFunction cluster,
423+
int size,
424+
IntPredicate isOverspill,
425+
DocIdsWriter docIdsWriter,
426+
DiskBBQBulkWriter bulkWriter
427+
) throws IOException {
428+
429+
offHeapQuantizedVectors.reset(size, isOverspill::test, cluster);
430+
// TODO we might want to consider putting the docIds in a separate file
431+
// to aid with only having to fetch vectors from slower storage when they are required
432+
// keeping them in the same file indicates we pull the entire file into cache
433+
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.apply(j)), size, postingsOutput);
434+
// write vectors
435+
bulkWriter.writeVectors(offHeapQuantizedVectors);
436+
}
437+
271438
private static void printClusterQualityStatistics(int[][] clusters) {
272439
float min = Float.MAX_VALUE;
273440
float max = Float.MIN_VALUE;

0 commit comments

Comments
 (0)