diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
index 3fceace3bbf51..d80fc216e556c 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
@@ -326,12 +326,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         final float[] centroid;
         long slicePos;
         OptimizedScalarQuantizer.QuantizationResult queryCorrections;
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
 
         final float[] scratch;
         final int[] quantizationScratch;
         final byte[] quantizedQueryScratch;
         final OptimizedScalarQuantizer quantizer;
+        final DocIdsWriter idsWriter = new DocIdsWriter();
         final float[] correctiveValues = new float[3];
         final long quantizedVectorByteSize;
 
@@ -369,7 +369,13 @@ public int resetPostingsScorer(long offset) throws IOException {
             vectors = indexInput.readVInt();
             // read the doc ids
             assert vectors <= docIdsScratch.length;
-            docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
+            idsWriter.readInts(indexInput, vectors, docIdsScratch);
+            // reconstitute from the deltas
+            int sum = 0;
+            for (int i = 0; i < vectors; i++) {
+                sum += docIdsScratch[i];
+                docIdsScratch[i] = sum;
+            }
             slicePos = indexInput.getFilePointer();
             return vectors;
         }
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index d16163d6934e8..5e696b74530a8 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -17,6 +17,7 @@
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IntroSorter;
 import org.apache.lucene.util.LongValues;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.IntToIntFunction;
@@ -101,7 +102,6 @@ LongValues buildAndWritePostingsLists(
         postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
         final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
@@ -109,6 +109,10 @@ LongValues buildAndWritePostingsLists(
             new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+        final int[] docIds = new int[maxPostingListSize];
+        final int[] docDeltas = new int[maxPostingListSize];
+        final int[] clusterOrds = new int[maxPostingListSize];
+        DocIdsWriter idsWriter = new DocIdsWriter();
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
@@ -121,11 +125,21 @@ LongValues buildAndWritePostingsLists(
             int size = cluster.length;
             // write docIds
             postingsOutput.writeVInt(size);
-            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
+            for (int j = 0; j < size; j++) {
+                docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
+                clusterOrds[j] = j;
+            }
+            // sort clusterOrds by doc id, so the cluster ordinals are written in doc id order
+            new IntSorter(clusterOrds, i -> docIds[i]).sort(0, size);
+            // encode doc deltas
+            for (int j = 0; j < size; j++) {
+                docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]];
+            }
+            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[clusterOrds[ord]]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
-            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
             // write vectors
             bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
@@ -233,12 +247,15 @@ LongValues buildAndWritePostingsLists(
             quantizedVectorsInput,
             fieldInfo.getVectorDimension()
         );
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         // write the max posting list size
         postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
+        final int[] docIds = new int[maxPostingListSize];
+        final int[] docDeltas = new int[maxPostingListSize];
+        final int[] clusterOrds = new int[maxPostingListSize];
+        DocIdsWriter idsWriter = new DocIdsWriter();
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
@@ -252,11 +269,21 @@ LongValues buildAndWritePostingsLists(
             // write docIds
             int size = cluster.length;
             postingsOutput.writeVInt(size);
-            offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
+            for (int j = 0; j < size; j++) {
+                docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
+                clusterOrds[j] = j;
+            }
+            // sort clusterOrds by doc id, so the cluster ordinals are written in doc id order
+            new IntSorter(clusterOrds, i -> docIds[i]).sort(0, size);
+            // encode doc deltas
+            for (int j = 0; j < size; j++) {
+                docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]];
+            }
+            offHeapQuantizedVectors.reset(size, ord -> isOverspill[clusterOrds[ord]], ord -> cluster[clusterOrds[ord]]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
-            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
             // write vectors
             bulkWriter.writeVectors(offHeapQuantizedVectors);
         }
@@ -717,4 +744,37 @@ public void readQuantizedVector(int ord, boolean isOverspill) throws IOException
             bitSum = quantizedVectorsInput.readShort();
         }
     }
+
+    private static class IntSorter extends IntroSorter {
+        int pivot = -1;
+        private final int[] arr;
+        private final IntToIntFunction func;
+
+        private IntSorter(int[] arr, IntToIntFunction func) {
+            this.arr = arr;
+            this.func = func;
+        }
+
+        @Override
+        protected void setPivot(int i) {
+            pivot = func.apply(arr[i]);
+        }
+
+        @Override
+        protected int comparePivot(int j) {
+            return Integer.compare(pivot, func.apply(arr[j]));
+        }
+
+        @Override
+        protected int compare(int a, int b) {
+            return Integer.compare(func.apply(arr[a]), func.apply(arr[b]));
+        }
+
+        @Override
+        protected void swap(int i, int j) {
+            final int tmp = arr[i];
+            arr[i] = arr[j];
+            arr[j] = tmp;
+        }
+    }
 }
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java
index d46a6301b60a2..257a1340eeff1 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java
@@ -19,18 +19,13 @@
 package org.elasticsearch.index.codec.vectors;
 
 import org.apache.lucene.index.PointValues.IntersectVisitor;
-import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.store.DataOutput;
 import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.util.ArrayUtil;
-import org.apache.lucene.util.DocBaseBitSetIterator;
-import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.IntsRef;
 import org.apache.lucene.util.LongsRef;
 import org.apache.lucene.util.hnsw.IntToIntFunction;
 
 import java.io.IOException;
-import java.util.Arrays;
 
 /**
  * This class is used to write and read the doc ids in a compressed format. The format is optimized
@@ -42,7 +37,6 @@ final class DocIdsWriter {
     public static final int DEFAULT_MAX_POINTS_IN_LEAF_NODE = 512;
 
     private static final byte CONTINUOUS_IDS = (byte) -2;
-    private static final byte BITSET_IDS = (byte) -1;
     private static final byte DELTA_BPV_16 = (byte) 16;
     private static final byte BPV_21 = (byte) 21;
     private static final byte BPV_24 = (byte) 24;
@@ -92,21 +86,11 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
         }
 
         int min2max = max - min + 1;
-        if (strictlySorted) {
-            if (min2max == count) {
-                // continuous ids, typically happens when segment is sorted
-                out.writeByte(CONTINUOUS_IDS);
-                out.writeVInt(docIds.apply(0));
-                return;
-            } else if (min2max <= (count << 4)) {
-                assert min2max > count : "min2max: " + min2max + ", count: " + count;
-                // Only trigger bitset optimization when max - min + 1 <= 16 * count in order to avoid
-                // expanding too much storage.
-                // A field with lower cardinality will have higher probability to trigger this optimization.
-                out.writeByte(BITSET_IDS);
-                writeIdsAsBitSet(docIds, count, out);
-                return;
-            }
+        if (strictlySorted && min2max == count) {
+            // continuous ids, typically happens when segment is sorted
+            out.writeByte(CONTINUOUS_IDS);
+            out.writeVInt(docIds.apply(0));
+            return;
         }
 
         if (min2max <= 0xFFFF) {
@@ -180,38 +164,6 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
         }
     }
 
-    private static void writeIdsAsBitSet(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
-        int min = docIds.apply(0);
-        int max = docIds.apply(count - 1);
-
-        final int offsetWords = min >> 6;
-        final int offsetBits = offsetWords << 6;
-        final int totalWordCount = FixedBitSet.bits2words(max - offsetBits + 1);
-        long currentWord = 0;
-        int currentWordIndex = 0;
-
-        out.writeVInt(offsetWords);
-        out.writeVInt(totalWordCount);
-        // build bit set streaming
-        for (int i = 0; i < count; i++) {
-            final int index = docIds.apply(i) - offsetBits;
-            final int nextWordIndex = index >> 6;
-            assert currentWordIndex <= nextWordIndex;
-            if (currentWordIndex < nextWordIndex) {
-                out.writeLong(currentWord);
-                currentWord = 0L;
-                currentWordIndex++;
-                while (currentWordIndex < nextWordIndex) {
-                    currentWordIndex++;
-                    out.writeLong(0L);
-                }
-            }
-            currentWord |= 1L << index;
-        }
-        out.writeLong(currentWord);
-        assert currentWordIndex + 1 == totalWordCount;
-    }
-
     /** Read {@code count} integers into {@code docIDs}. */
     void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
         if (count == 0) {
@@ -225,9 +177,6 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
             case CONTINUOUS_IDS:
                 readContinuousIds(in, count, docIDs);
                 break;
-            case BITSET_IDS:
-                readBitSet(in, count, docIDs);
-                break;
             case DELTA_BPV_16:
                 readDelta16(in, count, docIDs);
                 break;
@@ -245,20 +194,6 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
         }
     }
 
-    private DocIdSetIterator readBitSetIterator(IndexInput in, int count) throws IOException {
-        int offsetWords = in.readVInt();
-        int longLen = in.readVInt();
-        scratchLongs.longs = ArrayUtil.growNoCopy(scratchLongs.longs, longLen);
-        in.readLongs(scratchLongs.longs, 0, longLen);
-        // make ghost bits clear for FixedBitSet.
-        if (longLen < scratchLongs.length) {
-            Arrays.fill(scratchLongs.longs, longLen, scratchLongs.longs.length, 0);
-        }
-        scratchLongs.length = longLen;
-        FixedBitSet bitSet = new FixedBitSet(scratchLongs.longs, longLen << 6);
-        return new DocBaseBitSetIterator(bitSet, count, offsetWords << 6);
-    }
-
     private static void readContinuousIds(IndexInput in, int count, int[] docIDs) throws IOException {
         int start = in.readVInt();
         for (int i = 0; i < count; i++) {
@@ -266,15 +201,6 @@ private static void readContinuousIds(IndexInput in, int count, int[] docIDs) th
         }
     }
 
-    private void readBitSet(IndexInput in, int count, int[] docIDs) throws IOException {
-        DocIdSetIterator iterator = readBitSetIterator(in, count);
-        int docId, pos = 0;
-        while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
-            docIDs[pos++] = docId;
-        }
-        assert pos == count : "pos: " + pos + ", count: " + count;
-    }
-
     private static void readDelta16(IndexInput in, int count, int[] docIds) throws IOException {
         final int min = in.readVInt();
         final int half = count >> 1;
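Illustration (not part of the patch): a minimal sketch of the round trip the change introduces for each posting list. The writer sorts cluster ordinals by doc id (the patch does this with Lucene's IntroSorter via the new IntSorter; a boxed sort is used below only to keep the example self-contained), writes the gap between consecutive doc ids through DocIdsWriter, and the reader restores absolute doc ids with a running prefix sum, as resetPostingsScorer now does. The class and variable names below are hypothetical.

// DocDeltaRoundTrip.java -- standalone sketch, not part of the Elasticsearch sources.
import java.util.Arrays;
import java.util.Comparator;

public class DocDeltaRoundTrip {
    public static void main(String[] args) {
        // doc ids of one cluster, in vector-ordinal order (i.e. unsorted)
        int[] docIds = { 42, 7, 19, 3, 30 };
        int size = docIds.length;

        // sort ordinals by doc id; the real code uses IntroSorter to avoid boxing
        Integer[] clusterOrds = new Integer[size];
        for (int i = 0; i < size; i++) {
            clusterOrds[i] = i;
        }
        Arrays.sort(clusterOrds, Comparator.comparingInt(ord -> docIds[ord]));

        // encode: first entry is the smallest doc id, the rest are gaps to the previous doc id
        int[] docDeltas = new int[size];
        for (int j = 0; j < size; j++) {
            docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]];
        }
        System.out.println(Arrays.toString(docDeltas)); // [3, 4, 12, 11, 12]

        // decode: a running prefix sum restores the sorted doc ids, mirroring resetPostingsScorer
        int[] decoded = new int[size];
        int sum = 0;
        for (int i = 0; i < size; i++) {
            sum += docDeltas[i];
            decoded[i] = sum;
        }
        System.out.println(Arrays.toString(decoded)); // [3, 7, 19, 30, 42]
    }
}

Because the values handed to DocIdsWriter are now small non-negative gaps rather than raw, ordinal-ordered doc ids, they tend to fit into fewer bits per value in the existing bit-packed encodings.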