Skip to content

Commit 0421a8d

Browse files
authored
integrate into new bulk scorers (#15021)
Return the maximum score of the batch from bulkScore(), which allows callers to skip batches in which no result reaches the minimum similarity needed to appear in the result set. Utilize this in HNSW and exhaustive search to skip all topN heap manipulation for any bulk-scored batch that does not meet the minimum score. See #14013
1 parent 1a448f1 commit 0421a8d

File tree

7 files changed

+59
-22
lines changed

7 files changed

+59
-22
lines changed

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,21 +361,23 @@ private void search(
361361
}
362362
ords[numOrds++] = i;
363363
if (numOrds == ords.length) {
364-
scorer.bulkScore(ords, scores, numOrds);
365-
for (int j = 0; j < numOrds; j++) {
366-
knnCollector.incVisitedCount(1);
367-
knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]);
364+
knnCollector.incVisitedCount(numOrds);
365+
if (scorer.bulkScore(ords, scores, numOrds) > knnCollector.minCompetitiveSimilarity()) {
366+
for (int j = 0; j < numOrds; j++) {
367+
knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]);
368+
}
368369
}
369370
numOrds = 0;
370371
}
371372
}
372373
}
373374

374375
if (numOrds > 0) {
375-
scorer.bulkScore(ords, scores, numOrds);
376-
for (int j = 0; j < numOrds; j++) {
377-
knnCollector.incVisitedCount(1);
378-
knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]);
376+
knnCollector.incVisitedCount(numOrds);
377+
if (scorer.bulkScore(ords, scores, numOrds) > knnCollector.minCompetitiveSimilarity()) {
378+
for (int j = 0; j < numOrds; j++) {
379+
knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]);
380+
}
379381
}
380382
}
381383
}

lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public final boolean earlyTerminated() {
4747

4848
@Override
4949
public final void incVisitedCount(int count) {
50-
assert count > 0;
50+
assert count >= 0;
5151
this.visitedCount += count;
5252
}
5353

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,11 @@ void searchLevel(
335335
bulkNodes[numNodes++] = friendOrd;
336336
}
337337

338-
if (numNodes > 0) {
339-
numNodes = (int) Math.min((long) numNodes, results.visitLimit() - results.visitedCount());
340-
scorer.bulkScore(bulkNodes, bulkScores, numNodes);
341-
results.incVisitedCount(numNodes);
338+
numNodes = (int) Math.min((long) numNodes, results.visitLimit() - results.visitedCount());
339+
results.incVisitedCount(numNodes);
340+
if (numNodes > 0
341+
&& scorer.bulkScore(bulkNodes, bulkScores, numNodes)
342+
> results.minCompetitiveSimilarity()) {
342343
for (int i = 0; i < numNodes; i++) {
343344
int node = bulkNodes[i];
344345
float score = bulkScores[i];

lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,15 @@ public interface RandomVectorScorer {
4242
* @param nodes array of nodes to score.
4343
* @param scores output array of scores corresponding to each node.
4444
* @param numNodes number of nodes to score. Must not exceed length of nodes or scores arrays.
45+
* @return the maximum scored value of any node, or Float.NEGATIVE_INFINITY if numNodes == 0.
4546
*/
46-
default void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
47+
default float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
48+
float max = Float.NEGATIVE_INFINITY;
4749
for (int i = 0; i < numNodes; i++) {
4850
scores[i] = score(nodes[i]);
51+
max = Math.max(max, scores[i]);
4952
}
53+
return max;
5054
}
5155

5256
/**

lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,24 @@ final void checkOrdinal(int ord) {
8080
}
8181

8282
@Override
83-
public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
83+
public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
8484
int i = 0;
8585
final int limit = numNodes & ~3;
86+
float maxScore = Float.NEGATIVE_INFINITY;
8687
for (; i < limit; i += 4) {
8788
long offset1 = (long) nodes[i] * vectorByteSize;
8889
long offset2 = (long) nodes[i + 1] * vectorByteSize;
8990
long offset3 = (long) nodes[i + 2] * vectorByteSize;
9091
long offset4 = (long) nodes[i + 3] * vectorByteSize;
9192
vectorOp(seg, scratchScores, offset1, offset2, offset3, offset4, query.length);
9293
scores[i + 0] = normalizeRawScore(scratchScores[0]);
94+
maxScore = Math.max(maxScore, scores[i + 0]);
9395
scores[i + 1] = normalizeRawScore(scratchScores[1]);
96+
maxScore = Math.max(maxScore, scores[i + 1]);
9497
scores[i + 2] = normalizeRawScore(scratchScores[2]);
98+
maxScore = Math.max(maxScore, scores[i + 2]);
9599
scores[i + 3] = normalizeRawScore(scratchScores[3]);
100+
maxScore = Math.max(maxScore, scores[i + 3]);
96101
}
97102
// Handle remaining 1–3 nodes in bulk (if any)
98103
int remaining = numNodes - i;
@@ -102,9 +107,17 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
102107
long addr3 = (remaining > 2) ? (long) nodes[i + 2] * vectorByteSize : addr1;
103108
vectorOp(seg, scratchScores, addr1, addr2, addr3, addr3, query.length);
104109
scores[i] = normalizeRawScore(scratchScores[0]);
105-
if (remaining > 1) scores[i + 1] = normalizeRawScore(scratchScores[1]);
106-
if (remaining > 2) scores[i + 2] = normalizeRawScore(scratchScores[2]);
110+
maxScore = Math.max(maxScore, scores[i]);
111+
if (remaining > 1) {
112+
scores[i + 1] = normalizeRawScore(scratchScores[1]);
113+
maxScore = Math.max(maxScore, scores[i + 1]);
114+
}
115+
if (remaining > 2) {
116+
scores[i + 2] = normalizeRawScore(scratchScores[2]);
117+
maxScore = Math.max(maxScore, scores[i + 2]);
118+
}
107119
}
120+
return maxScore;
108121
}
109122

110123
abstract void vectorOp(

lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,10 @@ public float score(int node) {
289289
}
290290

291291
@Override
292-
public void bulkScore(int[] nodes, float[] scores, int numNodes) {
292+
public float bulkScore(int[] nodes, float[] scores, int numNodes) {
293293
int i = 0;
294294
long queryAddr = (long) queryOrd * vectorByteSize;
295+
float maxScore = Float.NEGATIVE_INFINITY;
295296
final int limit = numNodes & ~3;
296297
for (; i < limit; i += 4) {
297298
long offset1 = (long) nodes[i] * vectorByteSize;
@@ -300,9 +301,13 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) {
300301
long offset4 = (long) nodes[i + 3] * vectorByteSize;
301302
vectorOp(seg, scratchScores, queryAddr, offset1, offset2, offset3, offset4, dims);
302303
scores[i + 0] = normalizeRawScore(scratchScores[0]);
304+
maxScore = Math.max(maxScore, scores[i + 0]);
303305
scores[i + 1] = normalizeRawScore(scratchScores[1]);
306+
maxScore = Math.max(maxScore, scores[i + 1]);
304307
scores[i + 2] = normalizeRawScore(scratchScores[2]);
308+
maxScore = Math.max(maxScore, scores[i + 2]);
305309
scores[i + 3] = normalizeRawScore(scratchScores[3]);
310+
maxScore = Math.max(maxScore, scores[i + 3]);
306311
}
307312
// Handle remaining 1–3 nodes in bulk (if any)
308313
int remaining = numNodes - i;
@@ -312,9 +317,17 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) {
312317
long addr3 = (remaining > 2) ? (long) nodes[i + 2] * vectorByteSize : addr1;
313318
vectorOp(seg, scratchScores, queryAddr, addr1, addr2, addr3, addr1, dims);
314319
scores[i] = normalizeRawScore(scratchScores[0]);
315-
if (remaining > 1) scores[i + 1] = normalizeRawScore(scratchScores[1]);
316-
if (remaining > 2) scores[i + 2] = normalizeRawScore(scratchScores[2]);
320+
maxScore = Math.max(maxScore, scores[i]);
321+
if (remaining > 1) {
322+
scores[i + 1] = normalizeRawScore(scratchScores[1]);
323+
maxScore = Math.max(maxScore, scores[i + 1]);
324+
}
325+
if (remaining > 2) {
326+
scores[i + 2] = normalizeRawScore(scratchScores[2]);
327+
maxScore = Math.max(maxScore, scores[i + 2]);
328+
}
317329
}
330+
return maxScore;
318331
}
319332

320333
@Override

lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,13 @@ void assertBulkEqualsNonBulk(KnnVectorValues values, VectorSimilarityFunction si
260260
: flatVectorsScorer.getRandomVectorScorer(sim, values, randomFloatVector(dims));
261261
int[] indices = randomIndices(size);
262262
float[] expectedScores = new float[size];
263+
float expectedMaxScore = Float.NEGATIVE_INFINITY;
263264
for (int i = 0; i < size; i++) {
264265
expectedScores[i] = scorer.score(indices[i]);
266+
expectedMaxScore = Math.max(expectedMaxScore, expectedScores[i]);
265267
}
266268
float[] bulkScores = new float[size];
267-
scorer.bulkScore(indices, bulkScores, size);
269+
assertEquals(expectedMaxScore, scorer.bulkScore(indices, bulkScores, size), 0.001);
268270
assertArrayEquals(expectedScores, bulkScores, delta);
269271
assertNoScoreBeyondNumNodes(scorer, size);
270272
}
@@ -302,16 +304,18 @@ void assertScoresAgainstDefaultFlatScorer(KnnVectorValues values, VectorSimilari
302304
DefaultFlatVectorScorer.INSTANCE.getRandomVectorScorerSupplier(sim, values).scorer();
303305
defaultScorer.setScoringOrdinal(targetNode);
304306
float[] expectedScores = new float[size];
307+
float expectedMaxScore = Float.NEGATIVE_INFINITY;
305308
for (int i = 0; i < size; i++) {
306309
expectedScores[i] = defaultScorer.score(indices[i]);
310+
expectedMaxScore = Math.max(expectedMaxScore, expectedScores[i]);
307311
}
308312

309313
var supplier = flatVectorsScorer.getRandomVectorScorerSupplier(sim, values);
310314
for (var ss : List.of(supplier, supplier.copy())) {
311315
var updatableScorer = ss.scorer();
312316
updatableScorer.setScoringOrdinal(targetNode);
313317
float[] bulkScores = new float[size];
314-
updatableScorer.bulkScore(indices, bulkScores, size);
318+
assertEquals(expectedMaxScore, updatableScorer.bulkScore(indices, bulkScores, size), 0.001);
315319
assertArrayEquals(expectedScores, bulkScores, delta);
316320
}
317321
}

0 commit comments

Comments
 (0)