Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AbstractKnnCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand Down Expand Up @@ -232,8 +231,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector
}
return visitedDocs.getAndSet(docId) == false;
};
assert knnCollector instanceof AbstractKnnCollector;
AbstractKnnCollector knnCollectorImpl = (AbstractKnnCollector) knnCollector;
int nProbe = DYNAMIC_NPROBE;
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
Expand All @@ -259,7 +256,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
// Note, numCollected is doing the bare minimum here.
// TODO do we need to handle nested doc counts similarly to how we handle
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
while (centroidIterator.hasNext() && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
while (centroidIterator.hasNext()
&& (centroidsVisited < nProbe || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
Comment on lines +259 to +260
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously we made a bad assumption: that the collector is always an AbstractKnnCollector. Instead, we should use the appropriate signal, namely whether the collector has set a minCompetitiveSimilarity, which indicates that it has gathered enough vectors.

++centroidsVisited;
// todo do we actually need to know the score???
long offset = centroidIterator.nextPostingListOffset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.suggest.document.CompletionTerms;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;
Expand Down Expand Up @@ -146,7 +145,7 @@ public void searchNearestVectors(String field, byte[] target, KnnCollector colle
in.searchNearestVectors(field, target, collector, acceptDocs);
return;
}
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
in.searchNearestVectors(field, target, new TimeOutCheckingKnnCollector(collector), acceptDocs);
}

@Override
Expand All @@ -164,122 +163,25 @@ public void searchNearestVectors(String field, float[] target, KnnCollector coll
in.searchNearestVectors(field, target, collector, acceptDocs);
return;
}
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
in.searchNearestVectors(field, target, new TimeOutCheckingKnnCollector(collector), acceptDocs);
}

/**
 * Wraps the given accept-docs in a view that performs periodic query-cancellation checks.
 *
 * @param acceptDocs the accepted documents; may be {@code null}
 * @return a {@code TimeOutCheckingBitSet} when {@code acceptDocs} is {@code null} or already a
 *     {@link BitSet}, otherwise a {@code TimeOutCheckingBits} around {@code acceptDocs}
 */
private Bits createTimeOutCheckingBits(Bits acceptDocs) {
    // null is tested explicitly because instanceof is false for null, yet null must still take the
    // BitSet-backed wrapper (the cast of null is harmless).
    final boolean useBitSetWrapper = acceptDocs == null || acceptDocs instanceof BitSet;
    return useBitSetWrapper
        ? new TimeOutCheckingBitSet((BitSet) acceptDocs)
        : new TimeOutCheckingBits(acceptDocs);
}

/**
 * Read-only {@link BitSet} view over an optional delegate that checks for query cancellation
 * from {@link #get(int)}.
 *
 * <p>When the delegate is {@code null}, every index in {@code [0, maxDoc)} is reported as set and
 * the size-related methods report {@code maxDoc}. All mutating and scanning operations, as well as
 * {@link #ramBytesUsed()}, throw {@link UnsupportedOperationException}.
 */
private class TimeOutCheckingBitSet extends BitSet {
    /** A cancellation check runs on the first {@link #get(int)} call and on every 10th call after it. */
    private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;

    /** Message used by every unsupported operation. */
    private static final String UNSUPPORTED_MESSAGE = "not supported on TimeOutCheckingBitSet";

    /** Number of {@link #get(int)} calls made so far; drives the periodic cancellation check. */
    private int calls;

    /** Wrapped bit set, or {@code null} to treat every document as accepted. */
    private final BitSet inner;

    /** Value of the enclosing reader's {@code maxDoc()} captured at construction time. */
    private final int maxDoc;

    /**
     * @param inner the bit set to wrap, or {@code null} to accept all documents below {@code maxDoc}
     */
    private TimeOutCheckingBitSet(BitSet inner) {
        this.inner = inner;
        this.maxDoc = maxDoc();
    }

    @Override
    public void set(int i) {
        throw new UnsupportedOperationException(UNSUPPORTED_MESSAGE);
    }

    @Override
    public boolean getAndSet(int i) {
        throw new UnsupportedOperationException(UNSUPPORTED_MESSAGE);
    }

    @Override
    public void clear(int i) {
        throw new UnsupportedOperationException(UNSUPPORTED_MESSAGE);
    }

    @Override
    public void clear(int startIndex, int endIndex) {
        throw new UnsupportedOperationException(UNSUPPORTED_MESSAGE);
    }

    /** @return the delegate's cardinality, or {@code maxDoc} when there is no delegate */
    @Override
    public int cardinality() {
        if (inner == null) {
            return maxDoc;
        }
        return inner.cardinality();
    }

    /** @return the delegate's approximate cardinality, or {@code maxDoc} when there is no delegate */
    @Override
    public int approximateCardinality() {
        if (inner == null) {
            return maxDoc;
        }
        return inner.approximateCardinality();
    }

    @Override
    public int prevSetBit(int index) {
        throw new UnsupportedOperationException(UNSUPPORTED_MESSAGE);
    }

    @Override
    public int nextSetBit(int start, int end) {
        throw new UnsupportedOperationException(UNSUPPORTED_MESSAGE);
    }

    @Override
    public long ramBytesUsed() {
        throw new UnsupportedOperationException(UNSUPPORTED_MESSAGE);
    }

    /**
     * Reports whether {@code index} is accepted, running a cancellation check on the first call
     * and on every {@code MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK}-th call after it.
     */
    @Override
    public boolean get(int index) {
        if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
            queryCancellation.checkCancelled();
        }
        if (inner == null) {
            // if acceptDocs is null, we assume all docs are accepted
            return index >= 0 && index < maxDoc;
        }
        return inner.get(index);
    }

    /** @return the delegate's length, or {@code maxDoc} when there is no delegate */
    @Override
    public int length() {
        if (inner == null) {
            // if acceptDocs is null, we assume all docs are accepted
            return maxDoc;
        }
        // Previously returned 0 here, which made a non-empty delegate look empty to callers that
        // bounds-check against length().
        return inner.length();
    }
}

private class TimeOutCheckingBits implements Bits {
private class TimeOutCheckingKnnCollector extends KnnCollector.Decorator {
private final KnnCollector in;
private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
private final Bits updatedAcceptDocs;
private int calls;

private TimeOutCheckingBits(Bits acceptDocs) {
// when acceptDocs is null due to no doc deleted, we will instantiate a new one that would
// match all docs to allow timeout checking.
this.updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;
private TimeOutCheckingKnnCollector(KnnCollector in) {
super(in);
this.in = in;
}

@Override
public boolean get(int index) {
public boolean collect(int docId, float similarity) {
if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
queryCancellation.checkCancelled();
}
return updatedAcceptDocs.get(index);
}

@Override
public int length() {
return updatedAcceptDocs.length();
return in.collect(docId, similarity);
}
}
}
Expand Down