Skip to content

Commit 7fc9fd3

Browse files
authored
Further improve filtering by score. (#14970)
PRs #14906 and #14896 improved the efficiency of filtering by score. This PR tries to get some extra speedup by:
- Skipping filtering by score when applying a non-essential clause that doesn't have matches over the range of doc IDs being scored.
- Filtering on float[] scores rather than double[] scores whenever applicable so that vectorization can work on 2x more lanes at once.
- Filtering by score using `VectorUtil#filterByScore` instead of relying on the collector to do it.
1 parent 7fe43de commit 7fc9fd3

File tree

10 files changed

+237
-15
lines changed

10 files changed

+237
-15
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ Optimizations
218218

219219
* GITHUB#14976: Utilize docIdRunEnd on ReqExclBulkScorer. (Ge Song)
220220

221+
* GITHUB#14970: Further speed up filtering hits by score. (Adrien Grand)
222+
221223
Changes in Runtime Behavior
222224
---------------------
223225
* GITHUB#14823: Decrease TieredMergePolicy's default number of segments per

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,22 @@ private float quantizeFloat(float v, byte[] dest, int destIndex) {
310310
}
311311
}
312312

313+
@Override
314+
public int filterByScore(
315+
int[] docBuffer, float[] scoreBuffer, float minScoreInclusive, int upTo) {
316+
int newSize = 0;
317+
for (int i = 0; i < upTo; ++i) {
318+
int doc = docBuffer[i];
319+
float score = scoreBuffer[i];
320+
docBuffer[newSize] = doc;
321+
scoreBuffer[newSize] = score;
322+
if (score >= minScoreInclusive) {
323+
newSize++;
324+
}
325+
}
326+
return newSize;
327+
}
328+
313329
@Override
314330
public int filterByScore(
315331
int[] docBuffer, double[] scoreBuffer, double minScoreInclusive, int upTo) {

lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ float recalculateScalarQuantizationOffset(
101101
float minQuantile,
102102
float maxQuantile);
103103

104+
/**
105+
* Filters {@code docBuffer} and {@code scoreBuffer} in place by {@code minScoreInclusive}. The
106+
* entries of {@code docBuffer} and {@code scoreBuffer} at the same index form a pair; pairs whose
107+
* score is not greater than or equal to {@code minScoreInclusive} are removed from both arrays.
108+
*
109+
* @param docBuffer doc buffer contains docs (or some other value forms a pair with {@code
110+
* scoreBuffer})
111+
* @param scoreBuffer score buffer contains scores to be compared with {@code minScoreInclusive}
112+
* @param minScoreInclusive minimal required score to not be filtered out
113+
* @param upTo exclusive end index; only the first {@code upTo} pairs are filtered
114+
* @return the number of pairs left after filtering
115+
*/
116+
int filterByScore(int[] docBuffer, float[] scoreBuffer, float minScoreInclusive, int upTo);
117+
104118
/**
105119
* filter both {@code docBuffer} and {@code scoreBuffer} with {@code minScoreInclusive}, each
106120
* {@code docBuffer} and {@code scoreBuffer} of the same index forms a pair, pairs with score not

lucene/core/src/java/org/apache/lucene/search/BatchScoreBulkScorer.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import org.apache.lucene.util.Bits;
21+
import org.apache.lucene.util.VectorUtil;
2122

2223
/**
2324
* A bulk scorer used when {@link ScoreMode#needsScores()} is true and {@link
@@ -49,11 +50,16 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr
4950
for (scorer.nextDocsAndScores(max, acceptDocs, buffer);
5051
buffer.size > 0;
5152
scorer.nextDocsAndScores(max, acceptDocs, buffer)) {
53+
54+
// The collector already filters hits whose score is less than the minimum competitive score,
55+
// but doing it here is a bit more efficient.
56+
buffer.size =
57+
VectorUtil.filterByScore(
58+
buffer.docs, buffer.features, scorable.minCompetitiveScore, buffer.size);
59+
5260
for (int i = 0, size = buffer.size; i < size; i++) {
53-
float score = scorable.score = buffer.features[i];
54-
if (score >= scorable.minCompetitiveScore) {
55-
collector.collect(buffer.docs[i]);
56-
}
61+
scorable.score = buffer.features[i];
62+
collector.collect(buffer.docs[i]);
5763
}
5864
scorer.setMinCompetitiveScore(scorable.minCompetitiveScore);
5965
}

lucene/core/src/java/org/apache/lucene/search/BlockMaxConjunctionBulkScorer.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,24 @@ private void scoreWindowScoreFirst(
166166
return;
167167
}
168168

169+
// two equal consecutive values mean that the first clause always returns a score of zero, so we
170+
// don't need to filter hits by score again.
171+
boolean leadingClauseHasZeroScores = sumOfOtherClauses[1] == sumOfOtherClauses[0];
172+
169173
for (scorers[0].nextDocsAndScores(max, acceptDocs, docAndScoreBuffer);
170174
docAndScoreBuffer.size > 0;
171175
scorers[0].nextDocsAndScores(max, acceptDocs, docAndScoreBuffer)) {
172176

177+
if (leadingClauseHasZeroScores == false) {
178+
ScorerUtil.filterCompetitiveHits(
179+
docAndScoreBuffer, sumOfOtherClauses[1], scorable.minCompetitiveScore, scorers.length);
180+
}
181+
173182
docAndScoreAccBuffer.copyFrom(docAndScoreBuffer);
174183

175184
for (int i = 1; i < scorers.length; ++i) {
176185
double sumOfOtherClause = sumOfOtherClauses[i];
177-
if (sumOfOtherClause != sumOfOtherClauses[i - 1]) {
186+
if (i > 1 && sumOfOtherClause != sumOfOtherClauses[i - 1]) {
178187
// two equal consecutive values mean that the first clause always returns a score of zero,
179188
// so we don't need to filter hits by score again.
180189
ScorerUtil.filterCompetitiveHits(

lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.lucene.util.Bits;
2323
import org.apache.lucene.util.FixedBitSet;
2424
import org.apache.lucene.util.MathUtil;
25+
import org.apache.lucene.util.VectorUtil;
2526

2627
final class MaxScoreBulkScorer extends BulkScorer {
2728

@@ -40,6 +41,8 @@ final class MaxScoreBulkScorer extends BulkScorer {
4041
// Index of the first scorer that is required, this scorer and all following scorers are required
4142
// for a document to match.
4243
int firstRequiredScorer;
44+
// Index of the first scorer that may produce positive scores on this window.
45+
int firstNonNullScorer;
4346
// The minimum value of minCompetitiveScore that would produce a more favorable partitioning.
4447
float nextMinCompetitiveScore;
4548
private final long cost;
@@ -230,8 +233,39 @@ private void scoreInnerWindowSingleEssentialClause(
230233
docAndScoreBuffer.size > 0;
231234
top.scorer.nextDocsAndScores(upTo, acceptDocs, docAndScoreBuffer)) {
232235

233-
docAndScoreAccBuffer.copyFrom(docAndScoreBuffer);
234-
scoreNonEssentialClauses(collector, docAndScoreAccBuffer, firstEssentialScorer);
236+
if (firstNonNullScorer >= firstEssentialScorer) {
237+
// Note: firstNonNullScorer may be > firstEssentialScorer if minCompetitiveScore=0 since
238+
// hits with a score of 0 are still competitive.
239+
// No non-essential clause can produce a positive score on this window, so filter
// non-competitive hits and collect directly.
240+
241+
int[] docs = docAndScoreBuffer.docs;
242+
float[] scores = docAndScoreBuffer.features;
243+
int size = docAndScoreBuffer.size;
244+
size = VectorUtil.filterByScore(docs, scores, scorable.minCompetitiveScore, size);
245+
246+
for (int i = 0; i < size; ++i) {
247+
scorable.score = scores[i];
248+
collector.collect(docs[i]);
249+
}
250+
} else {
251+
// Filter based on float scores before promoting them to doubles so that vectorization can
252+
// work on 2x more values at once.
253+
ScorerUtil.filterCompetitiveHits(
254+
docAndScoreBuffer,
255+
maxScoreSums[firstEssentialScorer - 1],
256+
scorable.minCompetitiveScore,
257+
allScorers.length);
258+
259+
docAndScoreAccBuffer.copyFrom(docAndScoreBuffer);
260+
261+
// Apply the last non-essential clause here instead of delegating it to
262+
// `scoreNonEssentialClauses` so that it doesn't re-do filtering by score.
263+
DisiWrapper scorer = allScorers[firstEssentialScorer - 1];
264+
ScorerUtil.applyOptionalClause(docAndScoreAccBuffer, scorer.iterator, scorer.scorable);
265+
scorer.doc = scorer.iterator.docID();
266+
267+
scoreNonEssentialClauses(collector, docAndScoreAccBuffer, firstEssentialScorer - 1);
268+
}
235269
}
236270

237271
top.doc = top.iterator.docID();
@@ -250,11 +284,19 @@ private void scoreInnerWindowAsConjunction(LeafCollector collector, Bits acceptD
250284
docAndScoreBuffer.size > 0;
251285
lead1.scorer.nextDocsAndScores(max, acceptDocs, docAndScoreBuffer)) {
252286

287+
// Filter based on float scores before promoting them to doubles so that vectorization can
288+
// work on 2x more values at once.
289+
ScorerUtil.filterCompetitiveHits(
290+
docAndScoreBuffer,
291+
maxScoreSums[allScorers.length - 2],
292+
scorable.minCompetitiveScore,
293+
allScorers.length);
294+
253295
docAndScoreAccBuffer.copyFrom(docAndScoreBuffer);
254296

255297
for (int i = allScorers.length - 2; i >= firstRequiredScorer; --i) {
256298

257-
if (scorable.minCompetitiveScore > 0) {
299+
if (i < allScorers.length - 2 && scorable.minCompetitiveScore > 0) {
258300
ScorerUtil.filterCompetitiveHits(
259301
docAndScoreAccBuffer,
260302
maxScoreSums[i],
@@ -371,7 +413,7 @@ private void scoreNonEssentialClauses(
371413
throws IOException {
372414
numCandidates += buffer.size;
373415

374-
for (int i = numNonEssentialClauses - 1; i >= 0; --i) {
416+
for (int i = numNonEssentialClauses - 1; i >= firstNonNullScorer; --i) {
375417
DisiWrapper scorer = allScorers[i];
376418
assert scorable.minCompetitiveScore > 0
377419
: "All clauses are essential if minCompetitiveScore is equal to zero";
@@ -381,9 +423,20 @@ private void scoreNonEssentialClauses(
381423
scorer.doc = scorer.iterator.docID();
382424
}
383425

384-
for (int i = 0; i < buffer.size; ++i) {
385-
scorable.score = (float) buffer.scores[i];
386-
collector.collect(buffer.docs[i]);
426+
// The collector already filters hits whose score is less than the min competitive score, but
427+
// doing it here is a bit more efficient.
428+
int size = buffer.size;
429+
int[] docs = buffer.docs;
430+
docAndScoreBuffer.growNoCopy(size);
431+
float[] scores = docAndScoreBuffer.features;
432+
for (int i = 0; i < size; ++i) {
433+
scores[i] = (float) buffer.scores[i];
434+
}
435+
size = VectorUtil.filterByScore(docs, scores, scorable.minCompetitiveScore, size);
436+
437+
for (int i = 0; i < size; ++i) {
438+
scorable.score = scores[i];
439+
collector.collect(docs[i]);
387440
}
388441
}
389442

@@ -408,10 +461,14 @@ boolean partitionScorers() {
408461
(double) scorer2.maxWindowScore / Math.max(1L, scorer2.cost));
409462
});
410463
double maxScoreSum = 0;
464+
firstNonNullScorer = 0;
411465
firstEssentialScorer = 0;
412466
nextMinCompetitiveScore = Float.POSITIVE_INFINITY;
413467
for (int i = 0; i < allScorers.length; ++i) {
414468
final DisiWrapper w = scratch[i];
469+
if (w.maxWindowScore == 0f) {
470+
firstNonNullScorer = i + 1;
471+
}
415472
double newMaxScoreSum = maxScoreSum + w.maxWindowScore;
416473
float maxScoreSumFloat =
417474
(float) MathUtil.sumUpperBound(newMaxScoreSum, firstEssentialScorer + 1);

lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,34 @@ static double minRequiredScore(
138138
return minRequiredScore;
139139
}
140140

141+
/**
142+
* Filters competitive hits from the provided {@link DocAndFloatFeatureBuffer}.
143+
*
144+
* <p>This method removes documents from the buffer that cannot possibly have a score competitive
145+
* enough to exceed the minimum competitive score, given the maximum remaining score and the
146+
* number of scorers.
147+
*/
148+
static void filterCompetitiveHits(
149+
DocAndFloatFeatureBuffer buffer,
150+
double maxRemainingScore,
151+
float minCompetitiveScore,
152+
int numScorers) {
153+
double minRequiredScoreDouble =
154+
minRequiredScore(maxRemainingScore, minCompetitiveScore, numScorers);
155+
float minRequiredScoreFloat = (float) minRequiredScoreDouble;
156+
if ((double) minRequiredScoreFloat > minRequiredScoreDouble) { // the cast rounded up
157+
minRequiredScoreFloat = Math.nextDown(minRequiredScoreFloat);
158+
}
159+
assert (double) minRequiredScoreFloat <= minRequiredScoreDouble;
160+
161+
if (minRequiredScoreFloat <= 0) {
162+
return;
163+
}
164+
165+
buffer.size =
166+
VectorUtil.filterByScore(buffer.docs, buffer.features, minRequiredScoreFloat, buffer.size);
167+
}
168+
141169
/**
142170
* Filters competitive hits from the provided {@link DocAndScoreAccBuffer}.
143171
*

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,27 @@ public static float recalculateOffset(
377377
vector, oldAlpha, oldMinQuantile, scale, alpha, minQuantile, maxQuantile);
378378
}
379379

380+
/**
381+
* Filters {@code docBuffer} and {@code scoreBuffer} in place by {@code minScoreInclusive}. The
382+
* entries of {@code docBuffer} and {@code scoreBuffer} at the same index form a pair; pairs whose
383+
* score is not greater than or equal to {@code minScoreInclusive} are removed from both arrays.
384+
*
385+
* @param docBuffer doc buffer contains docs (or some other value forms a pair with {@code
386+
* scoreBuffer})
387+
* @param scoreBuffer score buffer contains scores to be compared with {@code minScoreInclusive}
388+
* @param minScoreInclusive minimal required score to not be filtered out
389+
* @param upTo exclusive end index; only the first {@code upTo} pairs are filtered
390+
* @return the number of pairs left after filtering
391+
*/
392+
public static int filterByScore(
393+
int[] docBuffer, float[] scoreBuffer, float minScoreInclusive, int upTo) {
394+
if (docBuffer.length < upTo || scoreBuffer.length < upTo) {
395+
throw new IllegalArgumentException(
396+
"docBuffer and scoreBuffer should be at least as long as upTo");
397+
}
398+
return IMPL.filterByScore(docBuffer, scoreBuffer, minScoreInclusive, upTo);
399+
}
400+
380401
/**
381402
* filter both {@code docBuffer} and {@code scoreBuffer} with {@code minScoreInclusive}, each
382403
* {@code docBuffer} and {@code scoreBuffer} of the same index forms a pair, pairs with score not
@@ -391,9 +412,9 @@ public static float recalculateOffset(
391412
*/
392413
public static int filterByScore(
393414
int[] docBuffer, double[] scoreBuffer, double minScoreInclusive, int upTo) {
394-
if (docBuffer.length != scoreBuffer.length || docBuffer.length < upTo) {
415+
if (docBuffer.length < upTo || scoreBuffer.length < upTo) {
395416
throw new IllegalArgumentException(
396-
"docBuffer and scoreBuffer should keep same length and at least as long as upTo");
417+
"docBuffer and scoreBuffer should be at least as long as upTo");
397418
}
398419
return IMPL.filterByScore(docBuffer, scoreBuffer, minScoreInclusive, upTo);
399420
}

lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,35 @@ public float recalculateScalarQuantizationOffset(
10051005
return correction;
10061006
}
10071007

1008+
@SuppressForbidden(reason = "Uses compress and cast only where fast and carefully contained")
1009+
@Override
1010+
public int filterByScore(
1011+
int[] docBuffer, float[] scoreBuffer, float minScoreInclusive, int upTo) {
1012+
int newUpto = 0;
1013+
int i = 0;
1014+
if (Constants.HAS_FAST_COMPRESS_MASK_CAST) {
1015+
for (int bound = FLOAT_SPECIES.loopBound(upTo); i < bound; i += FLOAT_SPECIES.length()) {
1016+
FloatVector scoreVector = FloatVector.fromArray(FLOAT_SPECIES, scoreBuffer, i);
1017+
IntVector docVector = IntVector.fromArray(INT_SPECIES, docBuffer, i);
1018+
VectorMask<Float> mask = scoreVector.compare(VectorOperators.GE, minScoreInclusive);
1019+
scoreVector.compress(mask).intoArray(scoreBuffer, newUpto);
1020+
docVector.compress(mask.cast(INT_SPECIES)).intoArray(docBuffer, newUpto);
1021+
newUpto += mask.trueCount();
1022+
}
1023+
}
1024+
1025+
for (; i < upTo; ++i) {
1026+
int doc = docBuffer[i];
1027+
float score = scoreBuffer[i];
1028+
docBuffer[newUpto] = doc;
1029+
scoreBuffer[newUpto] = score;
1030+
if (score >= minScoreInclusive) {
1031+
newUpto++;
1032+
}
1033+
}
1034+
return newUpto;
1035+
}
1036+
10081037
@SuppressForbidden(reason = "Uses compress and cast only where fast and carefully contained")
10091038
@Override
10101039
public int filterByScore(

0 commit comments

Comments
 (0)