Skip to content

Commit 18c73e0

Browse files
authored
Add bulk scoring option to VectorScorer interface (#15171)
Now that we have bulk scoring for our RandomScorer's, let's enable it for our VectorScorer interface. I only implemented it for float32 right now, as that is the only place where we are actually doing bulk scoring. But once the interface is in place, I imagine it can be easily implemented elsewhere. This shows a marginal performance improvement: baseline ``` recall latency(ms) netCPU avgCpuCount nDoc topK fanout maxConn beamWidth quantized index(s) index_docs/s force_merge(s) num_segments index_size(MB) vec_disk(MB) vec_RAM(MB) indexType 1.000 0.880 0.760 0.864 1000000 500 550 16 100 no 0.00 Infinity 0.06 0 0.00 0.000 0.000 HNSW ``` candidate: ``` recall latency(ms) netCPU avgCpuCount nDoc topK fanout maxConn beamWidth quantized index(s) index_docs/s force_merge(s) num_segments index_size(MB) vec_disk(MB) vec_RAM(MB) indexType 1.000 0.780 0.640 0.821 1000000 500 550 16 100 no 0.00 Infinity 0.06 0 0.00 0.000 0.000 HNSW ```
1 parent 0421a8d commit 18c73e0

File tree

5 files changed

+150
-12
lines changed

5 files changed

+150
-12
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ Optimizations
141141
* GITHUB#15160: Increased the size used for blocks of postings from 128 to 256.
142142
This gives a noticeable speedup to many queries. (Adrien Grand)
143143

144+
* GITHUB#15171: Add `VectorScorer.Bulk` for bulk iteration and scoring of vectors. This new interface is now
145+
used in AbstractKnnVectorQuery to make exact matches faster. (Ben Trent)
146+
144147
Bug Fixes
145148
---------------------
146149
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException

lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,19 @@
1818
package org.apache.lucene.codecs.lucene95;
1919

2020
import java.io.IOException;
21+
import java.util.List;
2122
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2223
import org.apache.lucene.codecs.lucene90.IndexedDISI;
2324
import org.apache.lucene.index.FloatVectorValues;
2425
import org.apache.lucene.index.VectorEncoding;
2526
import org.apache.lucene.index.VectorSimilarityFunction;
27+
import org.apache.lucene.search.ConjunctionUtils;
28+
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
2629
import org.apache.lucene.search.DocIdSetIterator;
2730
import org.apache.lucene.search.VectorScorer;
2831
import org.apache.lucene.store.IndexInput;
2932
import org.apache.lucene.store.RandomAccessInput;
33+
import org.apache.lucene.util.ArrayUtil;
3034
import org.apache.lucene.util.Bits;
3135
import org.apache.lucene.util.hnsw.RandomVectorScorer;
3236
import org.apache.lucene.util.packed.DirectMonotonicReader;
@@ -173,6 +177,30 @@ public float score() throws IOException {
173177
public DocIdSetIterator iterator() {
174178
return iterator;
175179
}
180+
181+
@Override
182+
public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) {
183+
final DocIdSetIterator matches =
184+
matchingDocs == null
185+
? iterator
186+
: ConjunctionUtils.createConjunction(List.of(matchingDocs, iterator), List.of());
187+
return (nextCount, liveDocs, buffer) -> {
188+
if (matches.docID() == -1) {
189+
matches.nextDoc();
190+
}
191+
buffer.growNoCopy(nextCount);
192+
int size = 0;
193+
for (int doc = matches.docID();
194+
doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount;
195+
doc = matches.nextDoc()) {
196+
if (liveDocs == null || liveDocs.get(doc)) {
197+
buffer.docs[size++] = doc;
198+
}
199+
}
200+
buffer.size = size;
201+
return randomVectorScorer.bulkScore(buffer.docs, buffer.features, size);
202+
};
203+
}
176204
};
177205
}
178206
}
@@ -266,6 +294,43 @@ public float score() throws IOException {
266294
public DocIdSetIterator iterator() {
267295
return iterator;
268296
}
297+
298+
@Override
299+
public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) {
300+
return new Bulk() {
301+
final DocIdSetIterator matches =
302+
matchingDocs == null
303+
? iterator
304+
: ConjunctionUtils.createConjunction(
305+
List.of(matchingDocs, iterator), List.of());
306+
int[] docIds = new int[0];
307+
308+
@Override
309+
public float nextDocsAndScores(
310+
int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
311+
if (matches.docID() == -1) {
312+
matches.nextDoc();
313+
}
314+
buffer.growNoCopy(nextCount);
315+
docIds = ArrayUtil.growNoCopy(docIds, nextCount);
316+
int size = 0;
317+
for (int doc = matches.docID();
318+
doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount;
319+
doc = matches.nextDoc()) {
320+
if (liveDocs == null || liveDocs.get(doc)) {
321+
buffer.docs[size] = iterator.index();
322+
docIds[size] = doc;
323+
++size;
324+
}
325+
}
326+
buffer.size = size;
327+
float maxScore = randomVectorScorer.bulkScore(buffer.docs, buffer.features, size);
328+
// copy back the real doc IDs
329+
System.arraycopy(docIds, 0, buffer.docs, 0, size);
330+
return maxScore;
331+
}
332+
};
333+
}
269334
};
270335
}
271336
}

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ private void search(
339339
int filteredDocCount = Math.min(acceptDocs.cost(), graph.size());
340340
Bits accepted = acceptDocs.bits();
341341
final Bits acceptedOrds = scorer.getAcceptOrds(accepted);
342-
boolean doHnsw = knnCollector.k() < scorer.maxOrd();
342+
int numVectors = scorer.maxOrd();
343+
boolean doHnsw = knnCollector.k() < numVectors;
343344
// The approximate number of vectors that would be visited if we did not filter
344345
int unfilteredVisit = HnswGraphSearcher.expectedVisitedNodes(knnCollector.k(), graph.size());
345346
if (unfilteredVisit >= filteredDocCount) {
@@ -354,7 +355,7 @@ private void search(
354355
int[] ords = new int[EXHAUSTIVE_BULK_SCORE_ORDS];
355356
float[] scores = new float[EXHAUSTIVE_BULK_SCORE_ORDS];
356357
int numOrds = 0;
357-
for (int i = 0; i < scorer.maxOrd(); i++) {
358+
for (int i = 0; i < numVectors; i++) {
358359
if (acceptedOrds == null || acceptedOrds.get(i)) {
359360
if (knnCollector.earlyTerminated()) {
360361
break;

lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -309,21 +309,28 @@ protected TopDocs exactSearch(
309309
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
310310
ScoreDoc topDoc = queue.top();
311311
DocIdSetIterator vectorIterator = vectorScorer.iterator();
312-
DocIdSetIterator conjunction =
313-
ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of());
314-
int doc;
315-
while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
312+
DocAndFloatFeatureBuffer buffer = new DocAndFloatFeatureBuffer();
313+
VectorScorer.Bulk bulkScorer = vectorScorer.bulk(acceptIterator);
314+
while (vectorIterator.docID() != DocIdSetIterator.NO_MORE_DOCS) {
316315
// Mark results as partial if timeout is met
317316
if (queryTimeout != null && queryTimeout.shouldExit()) {
318317
relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
319318
break;
320319
}
321-
assert vectorIterator.docID() == doc;
322-
float score = vectorScorer.score();
323-
if (score > topDoc.score) {
324-
topDoc.score = score;
325-
topDoc.doc = doc;
326-
topDoc = queue.updateTop();
320+
// iterator already takes live docs into account
321+
float maxScore = bulkScorer.nextDocsAndScores(64, null, buffer);
322+
if (maxScore < topDoc.score) {
323+
// all the scores in this batch are too low, skip
324+
continue;
325+
}
326+
for (int i = 0; i < buffer.size; i++) {
327+
float score = buffer.features[i];
328+
int doc = buffer.docs[i];
329+
if (score > topDoc.score) {
330+
topDoc.score = score;
331+
topDoc.doc = doc;
332+
topDoc = queue.updateTop();
333+
}
327334
}
328335
}
329336

lucene/core/src/java/org/apache/lucene/search/VectorScorer.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package org.apache.lucene.search;
1818

1919
import java.io.IOException;
20+
import java.util.List;
21+
import org.apache.lucene.util.Bits;
2022

2123
/**
2224
* Computes the similarity score between a given query vector and different document vectors. This
@@ -38,4 +40,64 @@ public interface VectorScorer {
3840
* @return a {@link DocIdSetIterator} over the documents.
3941
*/
4042
DocIdSetIterator iterator();
43+
44+
/**
45+
* An optional bulk scorer implementation that allows bulk scoring over the provided matching
46+
* docs. The iterator of this instance of VectorScorer should be used and iterated in conjunction
47+
* with the provided matchingDocs iterator to score only the documents that are present in both
48+
* iterators. If the provided matchingDocs iterator is null, then all documents should be scored.
49+
* Additionally, if the iterators are unpositioned (docID() == -1), this method should position
50+
* them to the first document.
51+
*
52+
* @param matchingDocs the documents to score
53+
* @return a {@link Bulk} scorer
54+
* @throws IOException if an exception occurs during bulk scorer creation
55+
* @lucene.experimental
56+
*/
57+
default Bulk bulk(DocIdSetIterator matchingDocs) throws IOException {
58+
final DocIdSetIterator iterator =
59+
matchingDocs == null
60+
? iterator()
61+
: ConjunctionUtils.createConjunction(List.of(matchingDocs, iterator()), List.of());
62+
if (iterator.docID() == -1) {
63+
iterator.nextDoc();
64+
}
65+
return (nextCount, liveDocs, buffer) -> {
66+
buffer.growNoCopy(nextCount);
67+
int size = 0;
68+
float maxScore = Float.NEGATIVE_INFINITY;
69+
for (int doc = iterator.docID();
70+
doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount;
71+
doc = iterator.nextDoc()) {
72+
if (liveDocs == null || liveDocs.get(doc)) {
73+
buffer.docs[size] = doc;
74+
buffer.features[size] = score();
75+
maxScore = Math.max(maxScore, buffer.features[size]);
76+
++size;
77+
}
78+
}
79+
buffer.size = size;
80+
return maxScore;
81+
};
82+
}
83+
84+
/**
85+
* Bulk scorer interface to score multiple vectors at once
86+
*
87+
* @lucene.experimental
88+
*/
89+
interface Bulk {
90+
/**
91+
* Score up to nextCount documents, store the results in the provided buffer. Behaves similarly
92+
* to {@link Scorer#nextDocsAndScores(int, Bits, DocAndFloatFeatureBuffer)}
93+
*
94+
* @param nextCount the maximum number of documents to score
95+
* @param liveDocs the live docs, or null if all docs are live
96+
* @param buffer the buffer to store the results
97+
* @return the max score of the scored documents
98+
* @throws IOException if an exception occurs during scoring
99+
*/
100+
float nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer)
101+
throws IOException;
102+
}
41103
}

0 commit comments

Comments
 (0)