diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
index 312172f251dda..2523c8dd3e2c3 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
@@ -297,8 +297,7 @@ private static void score(
     PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
         throws IOException {
         FieldEntry entry = fields.get(fieldInfo.number);
-        final int maxPostingListSize = indexInput.readVInt();
-        return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, needsScoring);
+        return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring);
     }

     @Override
@@ -341,7 +340,6 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             IndexInput indexInput,
             FieldEntry entry,
             FieldInfo fieldInfo,
-            int maxPostingListSize,
             IntPredicate needsScoring
         ) throws IOException {
             this.target = target;
@@ -358,7 +356,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             quantizedVectorByteSize = (discretizedDimensions / 8);
             quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction(), DEFAULT_LAMBDA, 1);
             osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
-            this.docIdsScratch = new int[maxPostingListSize];
+            this.docIdsScratch = new int[ES91OSQVectorsScorer.BULK_SIZE];
         }

         @Override
@@ -369,25 +367,23 @@ public int resetPostingsScorer(long offset) throws IOException {
             centroidDp = Float.intBitsToFloat(indexInput.readInt());
             vectors = indexInput.readVInt();
             // read the doc ids
-            assert vectors <= docIdsScratch.length;
-            docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
             slicePos = indexInput.getFilePointer();
             return vectors;
         }

-        private float scoreIndividually(int offset) throws IOException {
+        private float scoreIndividually() throws IOException {
             float maxScore = Float.NEGATIVE_INFINITY;
             // score individually, first the quantized byte chunk
             for (int j = 0; j < BULK_SIZE; j++) {
-                int doc = docIdsScratch[j + offset];
+                int doc = docIdsScratch[j];
                 if (doc != -1) {
-                    indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
                     float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
                     scores[j] = qcDist;
+                } else {
+                    indexInput.skipBytes(quantizedVectorByteSize);
                 }
             }
             // read in all corrections
-            indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize));
             indexInput.readFloats(correctionsLower, 0, BULK_SIZE);
             indexInput.readFloats(correctionsUpper, 0, BULK_SIZE);
             for (int j = 0; j < BULK_SIZE; j++) {
@@ -396,7 +392,7 @@ private float scoreIndividually(int offset) throws IOException {
             indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
             // Now apply corrections
             for (int j = 0; j < BULK_SIZE; j++) {
-                int doc = docIdsScratch[offset + j];
+                int doc = docIdsScratch[j];
                 if (doc != -1) {
                     scores[j] = osqVectorsScorer.score(
                         queryCorrections.lowerInterval(),
@@ -419,21 +415,20 @@ private float scoreIndividually(int offset) throws IOException {
             return maxScore;
         }

-        private static int docToBulkScore(int[] docIds, int offset, IntPredicate needsScoring) {
+        private static int docToBulkScore(int[] docIds, IntPredicate needsScoring) {
            int docToScore = ES91OSQVectorsScorer.BULK_SIZE;
            for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
-                final int idx = offset + i;
-                if (needsScoring.test(docIds[idx]) == false) {
-                    docIds[idx] = -1;
+                if (needsScoring.test(docIds[i]) == false) {
+                    docIds[i] = -1;
                     docToScore--;
                 }
             }
             return docToScore;
         }

-        private static void collectBulk(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) {
+        private static void collectBulk(int[] docIds, KnnCollector knnCollector, float[] scores) {
             for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
-                final int doc = docIds[offset + i];
+                final int doc = docIds[i];
                 if (doc != -1) {
                     knnCollector.collect(doc, scores[i]);
                 }
@@ -442,20 +437,23 @@ private static void collectBulk(int[] docIds, int offset, KnnCollec

         @Override
         public int visit(KnnCollector knnCollector) throws IOException {
+            indexInput.seek(slicePos);
             // block processing
             int scoredDocs = 0;
             int limit = vectors - BULK_SIZE + 1;
             int i = 0;
+
             for (; i < limit; i += BULK_SIZE) {
-                final int docsToBulkScore = docToBulkScore(docIdsScratch, i, needsScoring);
+                docIdsWriter.readInts(indexInput, BULK_SIZE, docIdsScratch);
+                final int docsToBulkScore = docToBulkScore(docIdsScratch, needsScoring);
                 if (docsToBulkScore == 0) {
+                    indexInput.skipBytes(BULK_SIZE * quantizedByteLength);
                     continue;
                 }
                 quantizeQueryIfNecessary();
-                indexInput.seek(slicePos + i * quantizedByteLength);
                 final float maxScore;
                 if (docsToBulkScore < BULK_SIZE / 2) {
-                    maxScore = scoreIndividually(i);
+                    maxScore = scoreIndividually();
                 } else {
                     maxScore = osqVectorsScorer.scoreBulk(
                         quantizedQueryScratch,
@@ -469,16 +467,17 @@ public int visit(KnnCollector knnCollector) throws IOException {
                     );
                 }
                 if (knnCollector.minCompetitiveSimilarity() < maxScore) {
-                    collectBulk(docIdsScratch, i, knnCollector, scores);
+                    collectBulk(docIdsScratch, knnCollector, scores);
                 }
                 scoredDocs += docsToBulkScore;
             }
             // process tail
-            for (; i < vectors; i++) {
-                int doc = docIdsScratch[i];
+            int tailLength = vectors - i;
+            docIdsWriter.readInts(indexInput, tailLength, docIdsScratch);
+            for (int j = 0; j < tailLength; j++) {
+                int doc = docIdsScratch[j];
                 if (needsScoring.test(doc)) {
                     quantizeQueryIfNecessary();
-                    indexInput.seek(slicePos + i * quantizedByteLength);
                     float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
                     indexInput.readFloats(correctiveValues, 0, 3);
                     final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
@@ -497,6 +496,8 @@ public int visit(KnnCollector knnCollector) throws IOException {
                     );
                     scoredDocs++;
                     knnCollector.collect(doc, score);
+                } else {
+                    indexInput.skipBytes(quantizedByteLength);
                 }
             }
             if (scoredDocs > 0) {
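Net effect on the read path: `docIdsScratch` shrinks from `maxPostingListSize` to a single `BULK_SIZE` block, and `visit` consumes the posting list strictly front-to-back, skipping a block's vector bytes wholesale when nothing in it needs scoring. Below is a minimal sketch of that access pattern, not the actual codec: plain `int` ids and a `ByteBuffer` stand in for `DocIdsWriter` and `IndexInput`, and `VEC_BYTES` is an illustrative stand-in for `quantizedByteLength`.

```java
import java.nio.ByteBuffer;
import java.util.function.IntPredicate;

/** Sketch of the streamed layout: [BULK ids][BULK vectors+corrections] repeated, then a tail. */
final class StreamedPostingListSketch {
    static final int BULK = 16;      // stands in for ES91OSQVectorsScorer.BULK_SIZE
    static final int VEC_BYTES = 32; // stands in for quantizedByteLength

    /** Returns how many docs matched; the position only ever moves forward, never seeks back. */
    static int visit(ByteBuffer in, int vectors, IntPredicate needsScoring) {
        int[] ids = new int[BULK];   // scratch is BULK-sized now, not maxPostingListSize
        int scored = 0;
        int i = 0;
        for (; i < vectors - BULK + 1; i += BULK) {
            for (int j = 0; j < BULK; j++) {
                ids[j] = in.getInt();                          // the block's doc ids come first
            }
            int toScore = 0;
            for (int j = 0; j < BULK; j++) {
                if (needsScoring.test(ids[j])) toScore++;
            }
            in.position(in.position() + BULK * VEC_BYTES);     // a real scorer reads these bytes;
            if (toScore == 0) continue;                        // here we skip them either way
            scored += toScore;
        }
        // tail: remaining ids are written in one run, followed by the tail vectors
        int tail = vectors - i;
        for (int j = 0; j < tail; j++) ids[j] = in.getInt();
        for (int j = 0; j < tail; j++) {
            if (needsScoring.test(ids[j])) scored++;
            in.position(in.position() + VEC_BYTES);
        }
        return scored;
    }
}
```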
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index d16163d6934e8..88532ebd5fd19 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -77,11 +77,9 @@ LongValues buildAndWritePostingsLists(
             }
         }

-        int maxPostingListSize = 0;
         int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
         for (int c = 0; c < centroidSupplier.size(); c++) {
             int size = centroidVectorCount[c];
-            maxPostingListSize = Math.max(maxPostingListSize, size);
             assignmentsByCluster[c] = new int[size];
         }
         Arrays.fill(centroidVectorCount, 0);
@@ -97,11 +95,8 @@ LongValues buildAndWritePostingsLists(
                 }
             }
         }
-        // write the max posting list size
-        postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
         final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
@@ -125,9 +120,8 @@ LongValues buildAndWritePostingsLists(
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
-            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
             // write vectors
-            bulkWriter.writeVectors(onHeapQuantizedVectors);
+            bulkWriter.writeVectors(onHeapQuantizedVectors, j -> floatVectorValues.ordToDoc(cluster[j]));
         }

         if (logger.isDebugEnabled()) {
@@ -203,12 +197,10 @@ LongValues buildAndWritePostingsLists(
             }
         }

-        int maxPostingListSize = 0;
         int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
         boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
         for (int c = 0; c < centroidSupplier.size(); c++) {
             int size = centroidVectorCount[c];
-            maxPostingListSize = Math.max(maxPostingListSize, size);
             assignmentsByCluster[c] = new int[size];
             isOverspillByCluster[c] = new boolean[size];
         }
@@ -233,11 +225,8 @@ LongValues buildAndWritePostingsLists(
             quantizedVectorsInput,
             fieldInfo.getVectorDimension()
         );
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
-        // write the max posting list size
-        postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
@@ -256,9 +245,8 @@ LongValues buildAndWritePostingsLists(
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
-            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
             // write vectors
-            bulkWriter.writeVectors(offHeapQuantizedVectors);
+            bulkWriter.writeVectors(offHeapQuantizedVectors, j -> floatVectorValues.ordToDoc(cluster[j]));
         }

         if (logger.isDebugEnabled()) {
@@ -342,7 +330,7 @@ private void writeCentroidsWithParents(
             osq,
             globalCentroid
         );
-        bulkWriter.writeVectors(parentQuantizeCentroid);
+        bulkWriter.writeVectors(parentQuantizeCentroid, null);
         int offset = 0;
         for (int i = 0; i < centroidGroups.centroids().length; i++) {
             centroidOutput.writeInt(offset);
@@ -359,7 +347,7 @@ private void writeCentroidsWithParents(
         for (int i = 0; i < centroidGroups.centroids().length; i++) {
             final int[] centroidAssignments = centroidGroups.vectors()[i];
             childrenQuantizeCentroid.reset(idx -> centroidAssignments[idx], centroidAssignments.length);
-            bulkWriter.writeVectors(childrenQuantizeCentroid);
+            bulkWriter.writeVectors(childrenQuantizeCentroid, null);
         }
         // write the centroid offsets at the end of the file
         for (int i = 0; i < centroidGroups.centroids().length; i++) {
@@ -389,7 +377,7 @@ private void writeCentroidsWithoutParents(
             osq,
             globalCentroid
         );
-        bulkWriter.writeVectors(quantizedCentroids);
+        bulkWriter.writeVectors(quantizedCentroids, null);
         // write the centroid offsets at the end of the file
         for (int i = 0; i < centroidSupplier.size(); i++) {
             centroidOutput.writeLong(offsets.get(i));
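On the write side the same interleaving is produced by handing the ord-to-doc mapping down to the bulk writer rather than writing one big id block (and a `maxPostingListSize` header) before the vectors. A sketch of the resulting file shape, assuming plain `int` id encoding: `IntUnaryOperator` plays the role of Lucene's `IntToIntFunction` and `DataOutputStream` the role of `IndexOutput`.

```java
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.function.IntUnaryOperator;

/** Sketch: doc ids are written per block, directly before that block's vectors. */
final class InterleavedPostingWriterSketch {
    static final int BULK = 16; // stands in for ES91OSQVectorsScorer.BULK_SIZE

    static void writePostingList(DataOutputStream out, int count, IntUnaryOperator ordToDoc, byte[][] quantized)
        throws IOException {
        int i = 0;
        for (; i < count - BULK + 1; i += BULK) {
            for (int j = 0; j < BULK; j++) {
                out.writeInt(ordToDoc.applyAsInt(i + j)); // block ids first ...
            }
            for (int j = 0; j < BULK; j++) {
                out.write(quantized[i + j]);              // ... then the block's vectors
            }
        }
        for (int j = i; j < count; j++) {
            out.writeInt(ordToDoc.applyAsInt(j));         // tail ids in one run
        }
        for (; i < count; i++) {
            out.write(quantized[i]);                      // tail vectors
        }
    }
}
```

Because the reader now learns each block's ids as it goes, the up-front `writeVInt(maxPostingListSize)` header has no remaining consumer, which is presumably why it is dropped on both the reader and writer sides.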
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java
index 1d7e7f74f6c14..095ddbe295d81 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java
@@ -10,6 +10,7 @@
 package org.elasticsearch.index.codec.vectors;

 import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.hnsw.IntToIntFunction;

 import java.io.IOException;

@@ -27,10 +28,11 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
         this.out = out;
     }

-    abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;
+    abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, IntToIntFunction docIds) throws IOException;

     static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
         private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
+        protected DocIdsWriter docIdsWriter = new DocIdsWriter();

         OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
             super(bulkSize, out);
@@ -38,10 +40,12 @@ static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
         }

         @Override
-        void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
+        void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, IntToIntFunction docIds) throws IOException {
             int limit = qvv.count() - bulkSize + 1;
             int i = 0;
             for (; i < limit; i += bulkSize) {
+                int offset = i;
+                docIdsWriter.writeDocIds(idx -> docIds.apply(offset + idx), bulkSize, out);
                 for (int j = 0; j < bulkSize; j++) {
                     byte[] qv = qvv.next();
                     corrections[j] = qvv.getCorrections();
@@ -50,6 +54,8 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx
                 writeCorrections(corrections);
             }
             // write tail
+            int offset = i;
+            docIdsWriter.writeDocIds(idx -> docIds.apply(offset + idx), qvv.count() - i, out);
             for (; i < qvv.count(); ++i) {
                 byte[] qv = qvv.next();
                 OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
@@ -94,7 +100,8 @@ static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
         }

         @Override
-        void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
+        void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, IntToIntFunction docIds) throws IOException {
+            assert docIds == null;
             int limit = qvv.count() - bulkSize + 1;
             int i = 0;
             for (; i < limit; i += bulkSize) {
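A small Java detail in `OneBitDiskBBQBulkWriter.writeVectors`: the loop variable `i` is mutated (`i += bulkSize`), and lambdas may only capture effectively final locals, which is why each `writeDocIds` call first copies it into `int offset = i;`. A stripped-down illustration of the same pattern, with `IntUnaryOperator` again standing in for `IntToIntFunction`:

```java
import java.util.function.IntUnaryOperator;

final class EffectivelyFinalCaptureSketch {
    /** Builds a per-block view of the ord-to-doc mapping shifted by the block start. */
    static IntUnaryOperator blockView(IntUnaryOperator ordToDoc, int i) {
        // Capturing `i` directly from a loop that reassigns it would not compile;
        // copying it into a local that is never reassigned makes the capture legal.
        final int offset = i;
        return idx -> ordToDoc.applyAsInt(offset + idx);
    }
}
```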
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java
index 2c0d2f3fc7449..4f38a6f64e435 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java
@@ -22,8 +22,10 @@
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -145,6 +147,30 @@ public void testSimpleOffHeapSize() throws IOException {
         }
     }

+    public void testSameVectorManyTimes() throws IOException {
+        float[] vector = randomVector(random().nextInt(12, 500));
+        try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+            for (int i = 0; i < 10_000; i++) {
+                Document doc = new Document();
+                doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
+                w.addDocument(doc);
+            }
+            w.commit();
+            if (rarely()) {
+                w.forceMerge(1);
+            }
+            try (IndexReader reader = DirectoryReader.open(w)) {
+                List<LeafReaderContext> subReaders = reader.leaves();
+                for (LeafReaderContext r : subReaders) {
+                    LeafReader leafReader = r.reader();
+                    TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE);
+                    assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
+                }
+
+            }
+        }
+    }
+
     // this is a modified version of lucene's TestSearchWithThreads test case
     public void testWithThreads() throws Exception {
         final int numThreads = random().nextInt(2, 5);
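The new test indexes 10,000 copies of one vector, which likely funnels most documents into a single posting list and so exercises the block-streaming read path (previously it would have leaned on the `maxPostingListSize`-sized scratch array), asserting per leaf via `LeafReader#searchNearestVectors`. A hypothetical whole-index variant of the same assertion, not part of this patch, could route through `IndexSearcher` and `KnnFloatVectorQuery`, which merges per-segment results:

```java
import java.io.IOException;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;

final class TopLevelKnnCheckSketch {
    /** With thousands of identical vectors indexed, a k=10 search should always come back full. */
    static void assertTopLevelKnn(IndexReader reader, float[] vector) throws IOException {
        IndexSearcher searcher = new IndexSearcher(reader);
        TopDocs topDocs = searcher.search(new KnnFloatVectorQuery("f", vector, 10), 10);
        if (topDocs.scoreDocs.length != Math.min(reader.maxDoc(), 10)) {
            throw new AssertionError("expected a full top-10 result");
        }
    }
}
```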