Skip to content

Commit a7b434b

Browse files
authored
Adjust knn query cancellation checks to use collector pattern (#132515)
1 parent cd0e4ae commit a7b434b

File tree

2 files changed

+11
-111
lines changed

2 files changed

+11
-111
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.apache.lucene.index.VectorEncoding;
2323
import org.apache.lucene.index.VectorSimilarityFunction;
2424
import org.apache.lucene.internal.hppc.IntObjectHashMap;
25-
import org.apache.lucene.search.AbstractKnnCollector;
2625
import org.apache.lucene.search.KnnCollector;
2726
import org.apache.lucene.store.ChecksumIndexInput;
2827
import org.apache.lucene.store.DataInput;
@@ -232,8 +231,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector
232231
}
233232
return visitedDocs.getAndSet(docId) == false;
234233
};
235-
assert knnCollector instanceof AbstractKnnCollector;
236-
AbstractKnnCollector knnCollectorImpl = (AbstractKnnCollector) knnCollector;
237234
int nProbe = DYNAMIC_NPROBE;
238235
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
239236
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
@@ -259,7 +256,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
259256
// Note, numCollected is doing the bare minimum here.
260257
// TODO do we need to handle nested doc counts similarly to how we handle
261258
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
262-
while (centroidIterator.hasNext() && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
259+
while (centroidIterator.hasNext()
260+
&& (centroidsVisited < nProbe || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
263261
++centroidsVisited;
264262
// todo do we actually need to know the score???
265263
long offset = centroidIterator.nextPostingListOffset();

server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java

Lines changed: 9 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.apache.lucene.search.KnnCollector;
2727
import org.apache.lucene.search.VectorScorer;
2828
import org.apache.lucene.search.suggest.document.CompletionTerms;
29-
import org.apache.lucene.util.BitSet;
3029
import org.apache.lucene.util.Bits;
3130
import org.apache.lucene.util.BytesRef;
3231
import org.apache.lucene.util.automaton.CompiledAutomaton;
@@ -146,7 +145,7 @@ public void searchNearestVectors(String field, byte[] target, KnnCollector colle
146145
in.searchNearestVectors(field, target, collector, acceptDocs);
147146
return;
148147
}
149-
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
148+
in.searchNearestVectors(field, target, new TimeOutCheckingKnnCollector(collector), acceptDocs);
150149
}
151150

152151
@Override
@@ -164,122 +163,25 @@ public void searchNearestVectors(String field, float[] target, KnnCollector coll
164163
in.searchNearestVectors(field, target, collector, acceptDocs);
165164
return;
166165
}
167-
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
166+
in.searchNearestVectors(field, target, new TimeOutCheckingKnnCollector(collector), acceptDocs);
168167
}
169168

170-
private Bits createTimeOutCheckingBits(Bits acceptDocs) {
171-
if (acceptDocs == null || acceptDocs instanceof BitSet) {
172-
return new TimeOutCheckingBitSet((BitSet) acceptDocs);
173-
}
174-
return new TimeOutCheckingBits(acceptDocs);
175-
}
176-
177-
private class TimeOutCheckingBitSet extends BitSet {
178-
private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
179-
private int calls;
180-
private final BitSet inner;
181-
private final int maxDoc;
182-
183-
private TimeOutCheckingBitSet(BitSet inner) {
184-
this.inner = inner;
185-
this.maxDoc = maxDoc();
186-
}
187-
188-
@Override
189-
public void set(int i) {
190-
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
191-
}
192-
193-
@Override
194-
public boolean getAndSet(int i) {
195-
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
196-
}
197-
198-
@Override
199-
public void clear(int i) {
200-
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
201-
}
202-
203-
@Override
204-
public void clear(int startIndex, int endIndex) {
205-
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
206-
}
207-
208-
@Override
209-
public int cardinality() {
210-
if (inner == null) {
211-
return maxDoc;
212-
}
213-
return inner.cardinality();
214-
}
215-
216-
@Override
217-
public int approximateCardinality() {
218-
if (inner == null) {
219-
return maxDoc;
220-
}
221-
return inner.approximateCardinality();
222-
}
223-
224-
@Override
225-
public int prevSetBit(int index) {
226-
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
227-
}
228-
229-
@Override
230-
public int nextSetBit(int start, int end) {
231-
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
232-
}
233-
234-
@Override
235-
public long ramBytesUsed() {
236-
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
237-
}
238-
239-
@Override
240-
public boolean get(int index) {
241-
if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
242-
queryCancellation.checkCancelled();
243-
}
244-
if (inner == null) {
245-
// if acceptDocs is null, we assume all docs are accepted
246-
return index >= 0 && index < maxDoc;
247-
}
248-
return inner.get(index);
249-
}
250-
251-
@Override
252-
public int length() {
253-
if (inner == null) {
254-
// if acceptDocs is null, we assume all docs are accepted
255-
return maxDoc;
256-
}
257-
return 0;
258-
}
259-
}
260-
261-
private class TimeOutCheckingBits implements Bits {
169+
private class TimeOutCheckingKnnCollector extends KnnCollector.Decorator {
170+
private final KnnCollector in;
262171
private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
263-
private final Bits updatedAcceptDocs;
264172
private int calls;
265173

266-
private TimeOutCheckingBits(Bits acceptDocs) {
267-
// when acceptDocs is null due to no doc deleted, we will instantiate a new one that would
268-
// match all docs to allow timeout checking.
269-
this.updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;
174+
private TimeOutCheckingKnnCollector(KnnCollector in) {
175+
super(in);
176+
this.in = in;
270177
}
271178

272179
@Override
273-
public boolean get(int index) {
180+
public boolean collect(int docId, float similarity) {
274181
if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
275182
queryCancellation.checkCancelled();
276183
}
277-
return updatedAcceptDocs.get(index);
278-
}
279-
280-
@Override
281-
public int length() {
282-
return updatedAcceptDocs.length();
184+
return in.collect(docId, similarity);
283185
}
284186
}
285187
}

0 commit comments

Comments
 (0)