Remove soar duplicate checking #132617

Merged
@@ -16,6 +16,7 @@
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
@@ -25,7 +26,6 @@

import java.io.IOException;
import java.util.Map;
import java.util.function.IntPredicate;

import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
@@ -294,11 +294,10 @@ private static void score(
}

@Override
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
throws IOException {
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, Bits acceptDocs) throws IOException {
FieldEntry entry = fields.get(fieldInfo.number);
final int maxPostingListSize = indexInput.readVInt();
return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, needsScoring);
return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, acceptDocs);
}

@Override
@@ -312,7 +311,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
final float[] target;
final FieldEntry entry;
final FieldInfo fieldInfo;
final IntPredicate needsScoring;
final Bits acceptDocs;
private final ES91OSQVectorsScorer osqVectorsScorer;
final float[] scores = new float[BULK_SIZE];
final float[] correctionsLower = new float[BULK_SIZE];
@@ -342,13 +341,13 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
FieldEntry entry,
FieldInfo fieldInfo,
int maxPostingListSize,
IntPredicate needsScoring
Bits acceptDocs
) throws IOException {
this.target = target;
this.indexInput = indexInput;
this.entry = entry;
this.fieldInfo = fieldInfo;
this.needsScoring = needsScoring;
this.acceptDocs = acceptDocs;
centroid = new float[fieldInfo.getVectorDimension()];
scratch = new float[target.length];
quantizationScratch = new int[target.length];
@@ -419,11 +418,12 @@ private float scoreIndividually(int offset) throws IOException {
return maxScore;
}

private static int docToBulkScore(int[] docIds, int offset, IntPredicate needsScoring) {
private static int docToBulkScore(int[] docIds, int offset, Bits acceptDocs) {
assert acceptDocs != null : "acceptDocs must not be null";
int docToScore = ES91OSQVectorsScorer.BULK_SIZE;
for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
final int idx = offset + i;
if (needsScoring.test(docIds[idx]) == false) {
if (acceptDocs.get(docIds[idx]) == false) {
docIds[idx] = -1;
docToScore--;
}
@@ -447,7 +447,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
int limit = vectors - BULK_SIZE + 1;
int i = 0;
for (; i < limit; i += BULK_SIZE) {
final int docsToBulkScore = docToBulkScore(docIdsScratch, i, needsScoring);
final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, i, acceptDocs);
if (docsToBulkScore == 0) {
continue;
}
@@ -476,7 +476,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
// process tail
for (; i < vectors; i++) {
int doc = docIdsScratch[i];
if (needsScoring.test(doc)) {
if (acceptDocs == null || acceptDocs.get(doc)) {
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
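Note on the hunks above: the posting visitor now receives Lucene's `Bits acceptDocs` directly instead of an `IntPredicate`, and `acceptDocs` may be null when there is no filter, so the null check moves to the call sites before scoring. A minimal sketch of that accept check (the wrapper class and method name are illustrative, not part of the PR):

```java
import org.apache.lucene.util.Bits;

// Sketch only: how a nullable Bits replaces the old IntPredicate check.
final class AcceptCheckSketch {
    // A null Bits means "no filter": every doc is scorable.
    static boolean shouldScore(Bits acceptDocs, int docId) {
        return acceptDocs == null || acceptDocs.get(docId);
    }
}
```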
@@ -29,12 +29,10 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;

import java.io.IOException;
import java.util.function.IntPredicate;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE;
@@ -224,13 +222,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
}
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1);
IntPredicate needsScoring = docId -> {
if (acceptDocs != null && acceptDocs.get(docId) == false) {
return false;
}
return visitedDocs.getAndSet(docId) == false;
};
int nProbe = DYNAMIC_NPROBE;
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
@@ -248,7 +239,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
}
CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, needsScoring);
PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs);
int centroidsVisited = 0;
long expectedDocs = 0;
long actualDocs = 0;
@@ -316,7 +307,7 @@ IndexInput postingListSlice(IndexInput postingListFile) throws IOException {
}
}

abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, Bits needsScoring)
throws IOException;

interface CentroidIterator {
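For contrast, the block removed in the reader above folded two jobs into one `IntPredicate`: applying the `acceptDocs` filter and deduplicating docs that appear in multiple posting lists (the "soar" overspill named in the PR title) via a `FixedBitSet`. A rough reconstruction of that removed helper, for readers scanning the diff (the wrapper class and factory method are illustrative):

```java
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

import java.util.function.IntPredicate;

// Sketch of the removed read-time dedup: reject filtered docs, then score a doc
// only on its first visit across posting lists.
final class RemovedPredicateSketch {
    static IntPredicate needsScoring(Bits acceptDocs, int maxDoc) {
        FixedBitSet visitedDocs = new FixedBitSet(maxDoc + 1);
        return docId -> {
            if (acceptDocs != null && acceptDocs.get(docId) == false) {
                return false;
            }
            return visitedDocs.getAndSet(docId) == false;
        };
    }
}
```

With this PR only `acceptDocs` is passed down; duplicate handling moves to the query layer shown in the next file.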
@@ -9,6 +9,8 @@

package org.elasticsearch.search.vectors;

import com.carrotsearch.hppc.IntHashSet;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
@@ -115,7 +117,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
filterWeight = null;
}
// we request numCands as we are using it as an approximation measure
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(numCands, indexSearcher);
// we need to ensure we are getting at least 2*k results to ensure we cover overspill duplicates
// TODO move the logic for automatically adjusting percentages/nprobe to the query, so we can only pass
// 2k to the collector.
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(Math.max(Math.round(2f * k), numCands), indexSearcher);
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
@@ -135,12 +140,23 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {

private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
IntHashSet dedup = new IntHashSet(results.scoreDocs.length * 4 / 3);
int deduplicateCount = 0;
for (ScoreDoc scoreDoc : results.scoreDocs) {
if (dedup.add(scoreDoc.doc)) {
deduplicateCount++;
}
}
ScoreDoc[] deduplicatedScoreDocs = new ScoreDoc[deduplicateCount];
dedup.clear();
int index = 0;
for (ScoreDoc scoreDoc : results.scoreDocs) {
if (dedup.add(scoreDoc.doc)) {
scoreDoc.doc += ctx.docBase;
deduplicatedScoreDocs[index++] = scoreDoc;
}
}
return results;
return new TopDocs(results.totalHits, deduplicatedScoreDocs);
}

TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
Expand Down