Commit bfefe03

Remove soar duplicate checking (#132617)
Through our various benchmarking runs, I have noticed we do a silly amount of work just handling duplicate vectors for overspill. When it comes to block scoring, it is likely much better to score the duplicates and deduplicate later. This is indeed the case, and the performance improvement grows as the number of vector operations increases.

## Multi-segment Cohere-wiki-768 8M

I ran every nprobe 5 times and picked the fastest.

### CANDIDATE

```
index_name                     index_type n_probe latency(ms) net_cpu_time(ms) avg_cpu_count QPS    recall visited    filter_selectivity
------------------------------ ---------- ------- ----------- ---------------- ------------- ------ ------ ---------- ------------------
cohere-wikipedia-docs-768d.vec ivf        10      7.12        0.00             0.00          140.45 0.80   83108.96   1.00
cohere-wikipedia-docs-768d.vec ivf        20      10.47       0.00             0.00          95.51  0.86   169324.80  1.00
cohere-wikipedia-docs-768d.vec ivf        50      19.86       0.00             0.00          50.35  0.91   461667.04  1.00
cohere-wikipedia-docs-768d.vec ivf        100     33.65       0.00             0.00          29.72  0.94   950007.20  1.00
cohere-wikipedia-docs-768d.vec ivf        200     57.04       0.00             0.00          17.53  0.95   1797631.04 1.00
cohere-wikipedia-docs-768d.vec ivf        500     124.30      0.00             0.00          8.05   0.96   4334902.24 1.00
cohere-wikipedia-docs-768d.vec ivf        1000    236.78      0.00             0.00          4.22   0.96   8521820.48 1.00
```

### BASELINE

```
index_name                     index_type n_probe latency(ms) net_cpu_time(ms) avg_cpu_count QPS    recall visited    filter_selectivity
------------------------------ ---------- ------- ----------- ---------------- ------------- ------ ------ ---------- ------------------
cohere-wikipedia-docs-768d.vec ivf        10      7.21        0.00             0.00          138.70 0.81   74077.53   1.00
cohere-wikipedia-docs-768d.vec ivf        20      10.83       0.00             0.00          92.34  0.86   144966.33  1.00
cohere-wikipedia-docs-768d.vec ivf        50      21.75       0.00             0.00          45.98  0.91   365150.68  1.00
cohere-wikipedia-docs-768d.vec ivf        100     38.25       0.00             0.00          26.14  0.93   698105.96  1.00
cohere-wikipedia-docs-768d.vec ivf        200     65.61       0.00             0.00          15.24  0.95   1278157.01 1.00
cohere-wikipedia-docs-768d.vec ivf        500     148.98      0.00             0.00          6.71   0.95   2890457.27 1.00
cohere-wikipedia-docs-768d.vec ivf        1000    281.02      0.00             0.00          3.56   0.95   4939370.44 1.00
```

## Single segment Cohere-wiki-1024 1M

My thought was that larger vectors might make block scoring more expensive, so picking individual vectors would be better. Same methodology as above.

### Candidate

```
index_name       index_type n_probe latency(ms) net_cpu_time(ms) avg_cpu_count QPS     recall visited   filter_selectivity
---------------- ---------- ------- ----------- ---------------- ------------- ------- ------ --------- ------------------
wiki1024en.train ivf        10      0.63        0.00             0.00          1587.30 0.81   6389.60   1.00
wiki1024en.train ivf        20      0.86        0.00             0.00          1162.79 0.88   12528.48  1.00
wiki1024en.train ivf        50      1.43        0.00             0.00          699.30  0.93   30627.04  1.00
wiki1024en.train ivf        100     2.30        0.00             0.00          434.78  0.95   61259.84  1.00
wiki1024en.train ivf        200     4.12        0.00             0.00          242.72  0.97   122569.44 1.00
wiki1024en.train ivf        500     9.64        0.00             0.00          103.73  0.98   307816.80 1.00
wiki1024en.train ivf        1000    18.79       0.00             0.00          53.22   0.98   618772.32 1.00
```

### Baseline

```
index_name       index_type n_probe latency(ms) net_cpu_time(ms) avg_cpu_count QPS     recall visited   filter_selectivity
---------------- ---------- ------- ----------- ---------------- ------------- ------- ------ --------- ------------------
wiki1024en.train ivf        10      0.65        0.00             0.00          1538.46 0.82   5680.72   1.00
wiki1024en.train ivf        20      0.84        0.00             0.00          1190.48 0.88   10677.40  1.00
wiki1024en.train ivf        50      1.49        0.00             0.00          671.14  0.94   24431.26  1.00
wiki1024en.train ivf        100     2.41        0.00             0.00          414.94  0.96   47000.85  1.00
wiki1024en.train ivf        200     4.56        0.00             0.00          219.30  0.97   91284.42  1.00
wiki1024en.train ivf        500     10.56       0.00             0.00          94.70   0.98   218185.33 1.00
wiki1024en.train ivf        1000    20.81       0.00             0.00          48.05   0.98   412137.05 1.00
```
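The core idea, sketched below in plain Java, is "score the duplicates, deduplicate later": SOAR overspill can write a vector into more than one posting list, so the same document may surface several times in a leaf's score-ordered hits, and keeping only the first (best-scoring) occurrence after collection is cheaper than consulting a visited-docs bitset for every vector during bulk scoring. The `Hit` record, `dedupByDoc` helper, and class name are illustrative stand-ins, not the production code.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Minimal sketch (not the actual Elasticsearch code) of deduplicating
// score-ordered hits after collection instead of filtering during scoring.
public class DedupSketch {
    record Hit(int doc, float score) {}

    static List<Hit> dedupByDoc(List<Hit> hitsSortedByScoreDesc) {
        Set<Integer> seen = new HashSet<>();
        List<Hit> out = new ArrayList<>(hitsSortedByScoreDesc.size());
        for (Hit hit : hitsSortedByScoreDesc) {
            if (seen.add(hit.doc())) { // first occurrence wins (highest score)
                out.add(hit);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<Hit> hits = List.of(new Hit(7, 0.98f), new Hit(3, 0.95f), new Hit(7, 0.91f), new Hit(5, 0.90f));
        System.out.println(dedupByDoc(hits)); // doc 7 kept once, at its best score
    }
}
```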
1 parent f8b2ed9 commit bfefe03

3 files changed: +33 additions, -26 deletions

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 11 additions & 11 deletions

```diff
@@ -16,6 +16,7 @@
 import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
@@ -25,7 +26,6 @@

 import java.io.IOException;
 import java.util.Map;
-import java.util.function.IntPredicate;

 import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
 import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
@@ -294,11 +294,10 @@ private static void score(
     }

     @Override
-    PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
-        throws IOException {
+    PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, Bits acceptDocs) throws IOException {
         FieldEntry entry = fields.get(fieldInfo.number);
         final int maxPostingListSize = indexInput.readVInt();
-        return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, needsScoring);
+        return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, acceptDocs);
     }

     @Override
@@ -312,7 +311,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         final float[] target;
         final FieldEntry entry;
         final FieldInfo fieldInfo;
-        final IntPredicate needsScoring;
+        final Bits acceptDocs;
         private final ES91OSQVectorsScorer osqVectorsScorer;
         final float[] scores = new float[BULK_SIZE];
         final float[] correctionsLower = new float[BULK_SIZE];
@@ -342,13 +341,13 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
            FieldEntry entry,
            FieldInfo fieldInfo,
            int maxPostingListSize,
-           IntPredicate needsScoring
+           Bits acceptDocs
        ) throws IOException {
            this.target = target;
            this.indexInput = indexInput;
            this.entry = entry;
            this.fieldInfo = fieldInfo;
-           this.needsScoring = needsScoring;
+           this.acceptDocs = acceptDocs;
            centroid = new float[fieldInfo.getVectorDimension()];
            scratch = new float[target.length];
            quantizationScratch = new int[target.length];
@@ -419,11 +418,12 @@ private float scoreIndividually(int offset) throws IOException {
            return maxScore;
        }

-       private static int docToBulkScore(int[] docIds, int offset, IntPredicate needsScoring) {
+       private static int docToBulkScore(int[] docIds, int offset, Bits acceptDocs) {
+           assert acceptDocs != null : "acceptDocs must not be null";
            int docToScore = ES91OSQVectorsScorer.BULK_SIZE;
            for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
                final int idx = offset + i;
-               if (needsScoring.test(docIds[idx]) == false) {
+               if (acceptDocs.get(docIds[idx]) == false) {
                    docIds[idx] = -1;
                    docToScore--;
                }
@@ -447,7 +447,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
            int limit = vectors - BULK_SIZE + 1;
            int i = 0;
            for (; i < limit; i += BULK_SIZE) {
-               final int docsToBulkScore = docToBulkScore(docIdsScratch, i, needsScoring);
+               final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, i, acceptDocs);
                if (docsToBulkScore == 0) {
                    continue;
                }
@@ -476,7 +476,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
            // process tail
            for (; i < vectors; i++) {
                int doc = docIdsScratch[i];
-               if (needsScoring.test(doc)) {
+               if (acceptDocs == null || acceptDocs.get(doc)) {
                    quantizeQueryIfNecessary();
                    indexInput.seek(slicePos + i * quantizedByteLength);
                    float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
```
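For reference, here is a minimal standalone sketch of the bulk pre-filtering pattern the new `docToBulkScore`/`visit` code relies on. The class, the `Bits` stand-in, and the `docsToBulkScore` name below are illustrative, not the Elasticsearch implementation: docs rejected by the live-docs/filter bitset are marked with `-1` so the bulk scorer can skip those slots, and a `null` bitset short-circuits to scoring the whole block with no per-doc check.

```java
// Illustrative sketch only; mirrors the filtering shape of the diff above.
public class BulkFilterSketch {
    static final int BULK_SIZE = 16; // stand-in for ES91OSQVectorsScorer.BULK_SIZE

    // Minimal stand-in for org.apache.lucene.util.Bits
    interface Bits {
        boolean get(int index);
    }

    static int docsToBulkScore(int[] docIds, int offset, Bits acceptDocs) {
        if (acceptDocs == null) {
            return BULK_SIZE; // no filter: score the whole block
        }
        int docsToScore = BULK_SIZE;
        for (int i = 0; i < BULK_SIZE; i++) {
            int idx = offset + i;
            if (acceptDocs.get(docIds[idx]) == false) {
                docIds[idx] = -1; // sentinel: the scorer ignores this slot
                docsToScore--;
            }
        }
        return docsToScore;
    }
}
```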

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 2 additions & 11 deletions

```diff
@@ -29,12 +29,10 @@
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
-import org.apache.lucene.util.FixedBitSet;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;

 import java.io.IOException;
-import java.util.function.IntPredicate;

 import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
 import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE;
@@ -224,13 +222,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector
            percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
        }
        int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
-       BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1);
-       IntPredicate needsScoring = docId -> {
-           if (acceptDocs != null && acceptDocs.get(docId) == false) {
-               return false;
-           }
-           return visitedDocs.getAndSet(docId) == false;
-       };
        int nProbe = DYNAMIC_NPROBE;
        // Search strategy may be null if this is being called from checkIndex (e.g. from a test)
        if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
@@ -248,7 +239,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
            nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
        }
        CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
-       PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, needsScoring);
+       PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs);
        int centroidsVisited = 0;
        long expectedDocs = 0;
        long actualDocs = 0;
@@ -316,7 +307,7 @@ IndexInput postingListSlice(IndexInput postingListFile) throws IOException {
        }
    }

-   abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
+   abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, Bits needsScoring)
        throws IOException;

    interface CentroidIterator {
```
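For contrast, below is a hedged standalone reconstruction of the gate this change removes, using plain `java.util.BitSet` and `java.util.function.IntPredicate` rather than Lucene's `FixedBitSet` so the snippet is self-contained. It shows why the old path was expensive: every posting-list entry paid a branch plus a read-modify-write on a maxDoc-sized bitset just to suppress overspill duplicates up front.

```java
import java.util.BitSet;
import java.util.function.IntPredicate;

// Illustrative reconstruction of the removed per-vector "needs scoring" gate.
public class RemovedVisitedGateSketch {
    static IntPredicate needsScoring(IntPredicate acceptDocs, int maxDoc) {
        BitSet visitedDocs = new BitSet(maxDoc + 1);
        return docId -> {
            if (acceptDocs != null && acceptDocs.test(docId) == false) {
                return false;
            }
            boolean seen = visitedDocs.get(docId); // getAndSet equivalent, in two steps
            visitedDocs.set(docId);
            return seen == false;
        };
    }

    public static void main(String[] args) {
        IntPredicate gate = needsScoring(null, 100);
        System.out.println(gate.test(7)); // true: first visit
        System.out.println(gate.test(7)); // false: duplicate suppressed up front
    }
}
```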

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 20 additions & 4 deletions

```diff
@@ -9,6 +9,8 @@

 package org.elasticsearch.search.vectors;

+import com.carrotsearch.hppc.IntHashSet;
+
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
@@ -115,7 +117,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
            filterWeight = null;
        }
        // we request numCands as we are using it as an approximation measure
-       KnnCollectorManager knnCollectorManager = getKnnCollectorManager(numCands, indexSearcher);
+       // we need to ensure we are getting at least 2*k results to ensure we cover overspill duplicates
+       // TODO move the logic for automatically adjusting percentages/nprobe to the query, so we can only pass
+       // 2k to the collector.
+       KnnCollectorManager knnCollectorManager = getKnnCollectorManager(Math.max(Math.round(2f * k), numCands), indexSearcher);
        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        List<LeafReaderContext> leafReaderContexts = reader.leaves();
        List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
@@ -135,12 +140,23 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {

    private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
        TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
-       if (ctx.docBase > 0) {
-           for (ScoreDoc scoreDoc : results.scoreDocs) {
+       IntHashSet dedup = new IntHashSet(results.scoreDocs.length * 4 / 3);
+       int deduplicateCount = 0;
+       for (ScoreDoc scoreDoc : results.scoreDocs) {
+           if (dedup.add(scoreDoc.doc)) {
+               deduplicateCount++;
+           }
+       }
+       ScoreDoc[] deduplicatedScoreDocs = new ScoreDoc[deduplicateCount];
+       dedup.clear();
+       int index = 0;
+       for (ScoreDoc scoreDoc : results.scoreDocs) {
+           if (dedup.add(scoreDoc.doc)) {
                scoreDoc.doc += ctx.docBase;
+               deduplicatedScoreDocs[index++] = scoreDoc;
            }
        }
-       return results;
+       return new TopDocs(results.totalHits, deduplicatedScoreDocs);
    }

    TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
```
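A hedged aside on the `Math.max(Math.round(2f * k), numCands)` sizing above (the helper class and numbers below are my own illustration, not from the commit): if, as the 2*k comment suggests, overspill writes each vector into at most one extra posting list, a document can appear at most twice among a leaf's hits, so collecting at least 2*k candidates still leaves at least k distinct documents after the per-leaf deduplication shown above.

```java
// Illustrative sketch only: mirrors the candidate-budget arithmetic from the diff above.
public class OverspillBudgetSketch {

    // same expression the query now uses to size the collector
    static int candidateBudget(int k, int numCands) {
        return Math.max(Math.round(2f * k), numCands);
    }

    public static void main(String[] args) {
        int k = 10;
        int budget = candidateBudget(k, 5); // even a tiny numCands is bumped up to 2 * k
        // worst case: every collected doc is part of a duplicate pair, so budget / 2 distinct docs remain
        System.out.println("budget=" + budget + ", distinct docs after dedup >= " + budget / 2);
    }
}
```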
