diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index bde7b1d5b60c0..e20b27836d680 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -22,7 +22,6 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; -import org.apache.lucene.search.AbstractKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -232,8 +231,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector } return visitedDocs.getAndSet(docId) == false; }; - assert knnCollector instanceof AbstractKnnCollector; - AbstractKnnCollector knnCollectorImpl = (AbstractKnnCollector) knnCollector; int nProbe = DYNAMIC_NPROBE; // Search strategy may be null if this is being called from checkIndex (e.g. from a test) if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) { @@ -259,7 +256,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector // Note, numCollected is doing the bare minimum here. // TODO do we need to handle nested doc counts similarly to how we handle // filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors? - while (centroidIterator.hasNext() && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) { + while (centroidIterator.hasNext() + && (centroidsVisited < nProbe || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) { ++centroidsVisited; // todo do we actually need to know the score??? long offset = centroidIterator.nextPostingListOffset(); diff --git a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java index 9c998eb920dc9..22c93032aecbe 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java @@ -26,7 +26,6 @@ import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.suggest.document.CompletionTerms; -import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.automaton.CompiledAutomaton; @@ -146,7 +145,7 @@ public void searchNearestVectors(String field, byte[] target, KnnCollector colle in.searchNearestVectors(field, target, collector, acceptDocs); return; } - in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs)); + in.searchNearestVectors(field, target, new TimeOutCheckingKnnCollector(collector), acceptDocs); } @Override @@ -164,122 +163,25 @@ public void searchNearestVectors(String field, float[] target, KnnCollector coll in.searchNearestVectors(field, target, collector, acceptDocs); return; } - in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs)); + in.searchNearestVectors(field, target, new TimeOutCheckingKnnCollector(collector), acceptDocs); } - private Bits createTimeOutCheckingBits(Bits acceptDocs) { - if (acceptDocs == null || acceptDocs instanceof BitSet) { - return new TimeOutCheckingBitSet((BitSet) acceptDocs); - } - return new TimeOutCheckingBits(acceptDocs); - } - - private class TimeOutCheckingBitSet extends BitSet { - private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10; - private int calls; - private final BitSet inner; - private final int maxDoc; - - private TimeOutCheckingBitSet(BitSet inner) { - this.inner = inner; - this.maxDoc = maxDoc(); - } - - @Override - public void set(int i) { - throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet"); - } - - @Override - public boolean getAndSet(int i) { - throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet"); - } - - @Override - public void clear(int i) { - throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet"); - } - - @Override - public void clear(int startIndex, int endIndex) { - throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet"); - } - - @Override - public int cardinality() { - if (inner == null) { - return maxDoc; - } - return inner.cardinality(); - } - - @Override - public int approximateCardinality() { - if (inner == null) { - return maxDoc; - } - return inner.approximateCardinality(); - } - - @Override - public int prevSetBit(int index) { - throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet"); - } - - @Override - public int nextSetBit(int start, int end) { - throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet"); - } - - @Override - public long ramBytesUsed() { - throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet"); - } - - @Override - public boolean get(int index) { - if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) { - queryCancellation.checkCancelled(); - } - if (inner == null) { - // if acceptDocs is null, we assume all docs are accepted - return index >= 0 && index < maxDoc; - } - return inner.get(index); - } - - @Override - public int length() { - if (inner == null) { - // if acceptDocs is null, we assume all docs are accepted - return maxDoc; - } - return 0; - } - } - - private class TimeOutCheckingBits implements Bits { + private class TimeOutCheckingKnnCollector extends KnnCollector.Decorator { + private final KnnCollector in; private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10; - private final Bits updatedAcceptDocs; private int calls; - private TimeOutCheckingBits(Bits acceptDocs) { - // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would - // match all docs to allow timeout checking. - this.updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs; + private TimeOutCheckingKnnCollector(KnnCollector in) { + super(in); + this.in = in; } @Override - public boolean get(int index) { + public boolean collect(int docId, float similarity) { if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) { queryCancellation.checkCancelled(); } - return updatedAcceptDocs.get(index); - } - - @Override - public int length() { - return updatedAcceptDocs.length(); + return in.collect(docId, similarity); } } }