Commit 0437da5

DiskBBQ posting list order (#132697)
Revisiting ordering the postings list now that we aren't keeping track of already-scored vectors. GroupVarInt doesn't add much given delta encoding; maybe we should just switch to optimized `for` encoding.
1 parent eac438b commit 0437da5
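
The core of the change, as a minimal standalone Java sketch (illustrative names only, not the actual DiskBBQ classes; the real code writes and reads through DocIdsWriter): once a posting list's doc ids are sorted ascending, only the gaps between consecutive ids are written, and the reader restores the absolute ids with a running sum, mirroring the loops added in the diffs below. Small gaps are cheap to encode, which is presumably why GroupVarInt on top adds little.

import java.util.Arrays;

// Minimal sketch of the delta encoding/decoding this commit introduces (illustrative names;
// the real code goes through DocIdsWriter and the posting-list file format).
public final class DocIdDeltaSketch {

    // Writer side: ascending doc ids -> gaps between consecutive ids.
    static int[] toDeltas(int[] sortedDocIds) {
        int[] deltas = new int[sortedDocIds.length];
        for (int j = 0; j < sortedDocIds.length; j++) {
            deltas[j] = j == 0 ? sortedDocIds[j] : sortedDocIds[j] - sortedDocIds[j - 1];
        }
        return deltas;
    }

    // Reader side: a running sum restores the absolute doc ids, like the
    // "reconstitute from the deltas" loop added to resetPostingsScorer.
    static int[] fromDeltas(int[] deltas) {
        int[] docIds = new int[deltas.length];
        int sum = 0;
        for (int i = 0; i < deltas.length; i++) {
            sum += deltas[i];
            docIds[i] = sum;
        }
        return docIds;
    }

    public static void main(String[] args) {
        int[] sorted = { 3, 7, 19, 42 };
        int[] deltas = toDeltas(sorted);                          // [3, 4, 12, 23]
        System.out.println(Arrays.toString(deltas));
        System.out.println(Arrays.toString(fromDeltas(deltas)));  // [3, 7, 19, 42]
    }
}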

3 files changed: +79 -87 lines changed
server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 8 additions & 2 deletions
@@ -326,12 +326,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         final float[] centroid;
         long slicePos;
         OptimizedScalarQuantizer.QuantizationResult queryCorrections;
-        DocIdsWriter docIdsWriter = new DocIdsWriter();

         final float[] scratch;
         final int[] quantizationScratch;
         final byte[] quantizedQueryScratch;
         final OptimizedScalarQuantizer quantizer;
+        final DocIdsWriter idsWriter = new DocIdsWriter();
         final float[] correctiveValues = new float[3];
         final long quantizedVectorByteSize;

@@ -369,7 +369,13 @@ public int resetPostingsScorer(long offset) throws IOException {
             vectors = indexInput.readVInt();
             // read the doc ids
             assert vectors <= docIdsScratch.length;
-            docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
+            idsWriter.readInts(indexInput, vectors, docIdsScratch);
+            // reconstitute from the deltas
+            int sum = 0;
+            for (int i = 0; i < vectors; i++) {
+                sum += docIdsScratch[i];
+                docIdsScratch[i] = sum;
+            }
             slicePos = indexInput.getFilePointer();
             return vectors;
         }

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 66 additions & 6 deletions
@@ -17,6 +17,7 @@
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IntroSorter;
 import org.apache.lucene.util.LongValues;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.IntToIntFunction;
@@ -101,14 +102,17 @@ LongValues buildAndWritePostingsLists(
         postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
         final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
             fieldInfo.getVectorDimension(),
             new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+        final int[] docIds = new int[maxPostingListSize];
+        final int[] docDeltas = new int[maxPostingListSize];
+        final int[] clusterOrds = new int[maxPostingListSize];
+        DocIdsWriter idsWriter = new DocIdsWriter();
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
@@ -121,11 +125,21 @@ LongValues buildAndWritePostingsLists(
             int size = cluster.length;
             // write docIds
             postingsOutput.writeVInt(size);
-            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
+            for (int j = 0; j < size; j++) {
+                docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
+                clusterOrds[j] = j;
+            }
+            // sort cluster.buffer by docIds values, this way cluster ordinals are sorted by docIds
+            new IntSorter(clusterOrds, i -> docIds[i]).sort(0, size);
+            // encode doc deltas
+            for (int j = 0; j < size; j++) {
+                docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]];
+            }
+            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[clusterOrds[ord]]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
-            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
             // write vectors
             bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
@@ -233,12 +247,15 @@ LongValues buildAndWritePostingsLists(
             quantizedVectorsInput,
             fieldInfo.getVectorDimension()
         );
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         // write the max posting list size
         postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
+        final int[] docIds = new int[maxPostingListSize];
+        final int[] docDeltas = new int[maxPostingListSize];
+        final int[] clusterOrds = new int[maxPostingListSize];
+        DocIdsWriter idsWriter = new DocIdsWriter();
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
@@ -252,11 +269,21 @@ LongValues buildAndWritePostingsLists(
             // write docIds
             int size = cluster.length;
             postingsOutput.writeVInt(size);
-            offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
+            for (int j = 0; j < size; j++) {
+                docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
+                clusterOrds[j] = j;
+            }
+            // sort cluster.buffer by docIds values, this way cluster ordinals are sorted by docIds
+            new IntSorter(clusterOrds, i -> docIds[i]).sort(0, size);
+            // encode doc deltas
+            for (int j = 0; j < size; j++) {
+                docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]];
+            }
+            offHeapQuantizedVectors.reset(size, ord -> isOverspill[clusterOrds[ord]], ord -> cluster[clusterOrds[ord]]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
-            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
             // write vectors
             bulkWriter.writeVectors(offHeapQuantizedVectors);
         }
@@ -717,4 +744,37 @@ public void readQuantizedVector(int ord, boolean isOverspill) throws IOException
             bitSum = quantizedVectorsInput.readShort();
         }
     }
+
+    private static class IntSorter extends IntroSorter {
+        int pivot = -1;
+        private final int[] arr;
+        private final IntToIntFunction func;
+
+        private IntSorter(int[] arr, IntToIntFunction func) {
+            this.arr = arr;
+            this.func = func;
+        }
+
+        @Override
+        protected void setPivot(int i) {
+            pivot = func.apply(arr[i]);
+        }
+
+        @Override
+        protected int comparePivot(int j) {
+            return Integer.compare(pivot, func.apply(arr[j]));
+        }
+
+        @Override
+        protected int compare(int a, int b) {
+            return Integer.compare(func.apply(arr[a]), func.apply(arr[b]));
+        }
+
+        @Override
+        protected void swap(int i, int j) {
+            final int tmp = arr[i];
+            arr[i] = arr[j];
+            arr[j] = tmp;
+        }
+    }
 }
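
For context only, not part of the diff: the IntSorter above permutes just the ordinal array via Lucene's IntroSorter, reading doc ids through a key function, presumably to avoid boxing and extra allocation while keeping the parallel per-ordinal data (cluster ordinals, overspill flags) addressable. A hypothetical, boxing-based stand-in for the `new IntSorter(clusterOrds, i -> docIds[i]).sort(0, size)` call, with made-up data:

import java.util.Arrays;

// Hypothetical stand-in for the IntSorter call above (illustrative data only).
public final class OrdinalSortDemo {
    public static void main(String[] args) {
        int[] docIds = { 42, 7, 19, 3 };        // doc id per cluster ordinal
        Integer[] clusterOrds = { 0, 1, 2, 3 }; // ordinals to permute
        Arrays.sort(clusterOrds, (a, b) -> Integer.compare(docIds[a], docIds[b]));
        // clusterOrds is now [3, 1, 2, 0]; docIds itself is untouched, so
        // docIds[clusterOrds[j]] walks the doc ids in ascending order: 3, 7, 19, 42.
        System.out.println(Arrays.toString(clusterOrds));
    }
}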

server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java

Lines changed: 5 additions & 79 deletions
@@ -19,18 +19,13 @@
 package org.elasticsearch.index.codec.vectors;

 import org.apache.lucene.index.PointValues.IntersectVisitor;
-import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.store.DataOutput;
 import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.util.ArrayUtil;
-import org.apache.lucene.util.DocBaseBitSetIterator;
-import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.IntsRef;
 import org.apache.lucene.util.LongsRef;
 import org.apache.lucene.util.hnsw.IntToIntFunction;

 import java.io.IOException;
-import java.util.Arrays;

 /**
  * This class is used to write and read the doc ids in a compressed format. The format is optimized
@@ -42,7 +37,6 @@ final class DocIdsWriter {
     public static final int DEFAULT_MAX_POINTS_IN_LEAF_NODE = 512;

     private static final byte CONTINUOUS_IDS = (byte) -2;
-    private static final byte BITSET_IDS = (byte) -1;
     private static final byte DELTA_BPV_16 = (byte) 16;
     private static final byte BPV_21 = (byte) 21;
     private static final byte BPV_24 = (byte) 24;
@@ -92,21 +86,11 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
         }

         int min2max = max - min + 1;
-        if (strictlySorted) {
-            if (min2max == count) {
-                // continuous ids, typically happens when segment is sorted
-                out.writeByte(CONTINUOUS_IDS);
-                out.writeVInt(docIds.apply(0));
-                return;
-            } else if (min2max <= (count << 4)) {
-                assert min2max > count : "min2max: " + min2max + ", count: " + count;
-                // Only trigger bitset optimization when max - min + 1 <= 16 * count in order to avoid
-                // expanding too much storage.
-                // A field with lower cardinality will have higher probability to trigger this optimization.
-                out.writeByte(BITSET_IDS);
-                writeIdsAsBitSet(docIds, count, out);
-                return;
-            }
+        if (strictlySorted && min2max == count) {
+            // continuous ids, typically happens when segment is sorted
+            out.writeByte(CONTINUOUS_IDS);
+            out.writeVInt(docIds.apply(0));
+            return;
         }

         if (min2max <= 0xFFFF) {
@@ -180,38 +164,6 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
         }
     }

-    private static void writeIdsAsBitSet(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
-        int min = docIds.apply(0);
-        int max = docIds.apply(count - 1);
-
-        final int offsetWords = min >> 6;
-        final int offsetBits = offsetWords << 6;
-        final int totalWordCount = FixedBitSet.bits2words(max - offsetBits + 1);
-        long currentWord = 0;
-        int currentWordIndex = 0;
-
-        out.writeVInt(offsetWords);
-        out.writeVInt(totalWordCount);
-        // build bit set streaming
-        for (int i = 0; i < count; i++) {
-            final int index = docIds.apply(i) - offsetBits;
-            final int nextWordIndex = index >> 6;
-            assert currentWordIndex <= nextWordIndex;
-            if (currentWordIndex < nextWordIndex) {
-                out.writeLong(currentWord);
-                currentWord = 0L;
-                currentWordIndex++;
-                while (currentWordIndex < nextWordIndex) {
-                    currentWordIndex++;
-                    out.writeLong(0L);
-                }
-            }
-            currentWord |= 1L << index;
-        }
-        out.writeLong(currentWord);
-        assert currentWordIndex + 1 == totalWordCount;
-    }
-
     /** Read {@code count} integers into {@code docIDs}. */
     void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
         if (count == 0) {
@@ -225,9 +177,6 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
             case CONTINUOUS_IDS:
                 readContinuousIds(in, count, docIDs);
                 break;
-            case BITSET_IDS:
-                readBitSet(in, count, docIDs);
-                break;
             case DELTA_BPV_16:
                 readDelta16(in, count, docIDs);
                 break;
@@ -245,36 +194,13 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
         }
     }

-    private DocIdSetIterator readBitSetIterator(IndexInput in, int count) throws IOException {
-        int offsetWords = in.readVInt();
-        int longLen = in.readVInt();
-        scratchLongs.longs = ArrayUtil.growNoCopy(scratchLongs.longs, longLen);
-        in.readLongs(scratchLongs.longs, 0, longLen);
-        // make ghost bits clear for FixedBitSet.
-        if (longLen < scratchLongs.length) {
-            Arrays.fill(scratchLongs.longs, longLen, scratchLongs.longs.length, 0);
-        }
-        scratchLongs.length = longLen;
-        FixedBitSet bitSet = new FixedBitSet(scratchLongs.longs, longLen << 6);
-        return new DocBaseBitSetIterator(bitSet, count, offsetWords << 6);
-    }
-
     private static void readContinuousIds(IndexInput in, int count, int[] docIDs) throws IOException {
         int start = in.readVInt();
         for (int i = 0; i < count; i++) {
             docIDs[i] = start + i;
         }
     }

-    private void readBitSet(IndexInput in, int count, int[] docIDs) throws IOException {
-        DocIdSetIterator iterator = readBitSetIterator(in, count);
-        int docId, pos = 0;
-        while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
-            docIDs[pos++] = docId;
-        }
-        assert pos == count : "pos: " + pos + ", count: " + count;
-    }
-
     private static void readDelta16(IndexInput in, int count, int[] docIds) throws IOException {
         final int min = in.readVInt();
         final int half = count >> 1;
