diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 312172f251dda..3fceace3bbf51 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -16,6 +16,7 @@ import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.NeighborQueue; import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; @@ -25,7 +26,6 @@ import java.io.IOException; import java.util.Map; -import java.util.function.IntPredicate; import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; @@ -294,11 +294,10 @@ private static void score( } @Override - PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring) - throws IOException { + PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, Bits acceptDocs) throws IOException { FieldEntry entry = fields.get(fieldInfo.number); final int maxPostingListSize = indexInput.readVInt(); - return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, needsScoring); + return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, acceptDocs); } @Override @@ -312,7 +311,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { final float[] target; final FieldEntry entry; final FieldInfo fieldInfo; - final IntPredicate needsScoring; + final Bits acceptDocs; private final ES91OSQVectorsScorer 
osqVectorsScorer; final float[] scores = new float[BULK_SIZE]; final float[] correctionsLower = new float[BULK_SIZE]; @@ -342,13 +341,13 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { FieldEntry entry, FieldInfo fieldInfo, int maxPostingListSize, - IntPredicate needsScoring + Bits acceptDocs ) throws IOException { this.target = target; this.indexInput = indexInput; this.entry = entry; this.fieldInfo = fieldInfo; - this.needsScoring = needsScoring; + this.acceptDocs = acceptDocs; centroid = new float[fieldInfo.getVectorDimension()]; scratch = new float[target.length]; quantizationScratch = new int[target.length]; @@ -419,11 +418,12 @@ private float scoreIndividually(int offset) throws IOException { return maxScore; } - private static int docToBulkScore(int[] docIds, int offset, IntPredicate needsScoring) { + private static int docToBulkScore(int[] docIds, int offset, Bits acceptDocs) { + assert acceptDocs != null : "acceptDocs must not be null"; int docToScore = ES91OSQVectorsScorer.BULK_SIZE; for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) { final int idx = offset + i; - if (needsScoring.test(docIds[idx]) == false) { + if (acceptDocs.get(docIds[idx]) == false) { docIds[idx] = -1; docToScore--; } @@ -447,7 +447,7 @@ public int visit(KnnCollector knnCollector) throws IOException { int limit = vectors - BULK_SIZE + 1; int i = 0; for (; i < limit; i += BULK_SIZE) { - final int docsToBulkScore = docToBulkScore(docIdsScratch, i, needsScoring); + final int docsToBulkScore = acceptDocs == null ? 
BULK_SIZE : docToBulkScore(docIdsScratch, i, acceptDocs); if (docsToBulkScore == 0) { continue; } @@ -476,7 +476,7 @@ public int visit(KnnCollector knnCollector) throws IOException { // process tail for (; i < vectors; i++) { int doc = docIdsScratch[i]; - if (needsScoring.test(doc)) { + if (acceptDocs == null || acceptDocs.get(doc)) { quantizeQueryIfNecessary(); indexInput.seek(slicePos + i * quantizedByteLength); float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index e20b27836d680..0043f78590ac1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -29,12 +29,10 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.FixedBitSet; import org.elasticsearch.core.IOUtils; import org.elasticsearch.search.vectors.IVFKnnSearchStrategy; import java.io.IOException; -import java.util.function.IntPredicate; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE; @@ -224,13 +222,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length())); } int numVectors = rawVectorsReader.getFloatVectorValues(field).size(); - BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1); - IntPredicate needsScoring = docId -> { - if (acceptDocs != null && acceptDocs.get(docId) == false) { - return false; - } - return visitedDocs.getAndSet(docId) == false; - }; int nProbe = DYNAMIC_NPROBE; // Search strategy may be null 
if this is being called from checkIndex (e.g. from a test) if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) { @@ -248,7 +239,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1); } CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target); - PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, needsScoring); + PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs); int centroidsVisited = 0; long expectedDocs = 0; long actualDocs = 0; @@ -316,7 +307,7 @@ IndexInput postingListSlice(IndexInput postingListFile) throws IOException { } } - abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring) + abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, Bits needsScoring) throws IOException; interface CentroidIterator { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java index 18d11ce667d24..16b32c46972bc 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java @@ -9,6 +9,8 @@ package org.elasticsearch.search.vectors; +import com.carrotsearch.hppc.IntHashSet; + import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -115,7 +117,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { filterWeight = null; } // we request numCands as we are using it as an approximation measure - 
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(numCands, indexSearcher); + // we need to ensure we are getting at least 2*k results to ensure we cover overspill duplicates + // TODO move the logic for automatically adjusting percentages/nprobe to the query, so we can only pass + // 2k to the collector. + KnnCollectorManager knnCollectorManager = getKnnCollectorManager(Math.max(Math.round(2f * k), numCands), indexSearcher); TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); List<LeafReaderContext> leafReaderContexts = reader.leaves(); List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size()); @@ -135,12 +140,23 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException { TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager); - if (ctx.docBase > 0) { - for (ScoreDoc scoreDoc : results.scoreDocs) { + IntHashSet dedup = new IntHashSet(results.scoreDocs.length * 4 / 3); + int deduplicateCount = 0; + for (ScoreDoc scoreDoc : results.scoreDocs) { + if (dedup.add(scoreDoc.doc)) { + deduplicateCount++; + } + } + ScoreDoc[] deduplicatedScoreDocs = new ScoreDoc[deduplicateCount]; + dedup.clear(); + int index = 0; + for (ScoreDoc scoreDoc : results.scoreDocs) { + if (dedup.add(scoreDoc.doc)) { scoreDoc.doc += ctx.docBase; + deduplicatedScoreDocs[index++] = scoreDoc; } } - return results; + return new TopDocs(results.totalHits, deduplicatedScoreDocs); } TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {