Skip to content

Commit 42b7b78

Browse files
authored
[ES|QL] Infer the score mode to use from the Lucene collector (#125930)
This change infers the score mode from the Lucene collector when the top-N collector is used, instead of deriving it from a fixed choice between TOP_DOCS and TOP_DOCS_WITH_SCORES.
1 parent 8028d5a commit 42b7b78

File tree

13 files changed

+125
-79
lines changed

13 files changed

+125
-79
lines changed

docs/changelog/125930.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 125930
2+
summary: Infer the score mode to use from the Lucene collector
3+
area: "ES|QL"
4+
type: enhancement
5+
issues: []

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public Factory(
4949
int taskConcurrency,
5050
int limit
5151
) {
52-
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
52+
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
5353
}
5454

5555
@Override

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.List;
2424
import java.util.function.Function;
2525

26+
import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
27+
2628
/**
2729
* Factory that generates an operator that finds the max value of a field using the {@link LuceneMinMaxOperator}.
2830
*/
@@ -121,7 +123,7 @@ public LuceneMaxFactory(
121123
NumberType numberType,
122124
int limit
123125
) {
124-
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
126+
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
125127
this.fieldName = fieldName;
126128
this.numberType = numberType;
127129
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.List;
2424
import java.util.function.Function;
2525

26+
import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
27+
2628
/**
2729
* Factory that generates an operator that finds the min value of a field using the {@link LuceneMinMaxOperator}.
2830
*/
@@ -121,7 +123,7 @@ public LuceneMinFactory(
121123
NumberType numberType,
122124
int limit
123125
) {
124-
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
126+
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
125127
this.fieldName = fieldName;
126128
this.numberType = numberType;
127129
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.apache.lucene.search.BulkScorer;
1212
import org.apache.lucene.search.ConstantScoreQuery;
1313
import org.apache.lucene.search.DocIdSetIterator;
14-
import org.apache.lucene.search.IndexSearcher;
1514
import org.apache.lucene.search.LeafCollector;
1615
import org.apache.lucene.search.Query;
1716
import org.apache.lucene.search.ScoreMode;
@@ -84,28 +83,27 @@ public abstract static class Factory implements SourceOperator.SourceOperatorFac
8483
protected final DataPartitioning dataPartitioning;
8584
protected final int taskConcurrency;
8685
protected final int limit;
87-
protected final ScoreMode scoreMode;
86+
protected final boolean needsScore;
8887
protected final LuceneSliceQueue sliceQueue;
8988

9089
/**
9190
* Build the factory.
9291
*
93-
* @param scoreMode the {@link ScoreMode} passed to {@link IndexSearcher#createWeight}
92+
* @param needsScore Whether the score is needed.
9493
*/
9594
protected Factory(
9695
List<? extends ShardContext> contexts,
97-
Function<ShardContext, Query> queryFunction,
96+
Function<ShardContext, Weight> weightFunction,
9897
DataPartitioning dataPartitioning,
9998
int taskConcurrency,
10099
int limit,
101-
ScoreMode scoreMode
100+
boolean needsScore
102101
) {
103102
this.limit = limit;
104-
this.scoreMode = scoreMode;
105103
this.dataPartitioning = dataPartitioning;
106-
var weightFunction = weightFunction(queryFunction, scoreMode);
107104
this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
108105
this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
106+
this.needsScore = needsScore;
109107
}
110108

111109
public final int taskConcurrency() {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.apache.lucene.search.LeafCollector;
1212
import org.apache.lucene.search.Query;
1313
import org.apache.lucene.search.Scorable;
14-
import org.apache.lucene.search.ScoreMode;
1514
import org.elasticsearch.compute.data.BlockFactory;
1615
import org.elasticsearch.compute.data.DocBlock;
1716
import org.elasticsearch.compute.data.DocVector;
@@ -56,17 +55,24 @@ public Factory(
5655
int taskConcurrency,
5756
int maxPageSize,
5857
int limit,
59-
boolean scoring
58+
boolean needsScore
6059
) {
61-
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
60+
super(
61+
contexts,
62+
weightFunction(queryFunction, needsScore ? COMPLETE : COMPLETE_NO_SCORES),
63+
dataPartitioning,
64+
taskConcurrency,
65+
limit,
66+
needsScore
67+
);
6268
this.maxPageSize = maxPageSize;
6369
// TODO: use a single limiter for multiple stage execution
6470
this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
6571
}
6672

6773
@Override
6874
public SourceOperator get(DriverContext driverContext) {
69-
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
75+
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore);
7076
}
7177

7278
public int maxPageSize() {
@@ -81,8 +87,8 @@ public String describe() {
8187
+ maxPageSize
8288
+ ", limit = "
8389
+ limit
84-
+ ", scoreMode = "
85-
+ scoreMode
90+
+ ", needsScore = "
91+
+ needsScore
8692
+ "]";
8793
}
8894
}
@@ -94,7 +100,7 @@ public LuceneSourceOperator(
94100
LuceneSliceQueue sliceQueue,
95101
int limit,
96102
Limiter limiter,
97-
ScoreMode scoreMode
103+
boolean needsScore
98104
) {
99105
super(blockFactory, maxPageSize, sliceQueue);
100106
this.minPageSize = Math.max(1, maxPageSize / 2);
@@ -104,7 +110,7 @@ public LuceneSourceOperator(
104110
boolean success = false;
105111
try {
106112
this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
107-
if (scoreMode.needsScores()) {
113+
if (needsScore) {
108114
scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
109115
this.leafCollector = new ScoringCollector();
110116
} else {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import org.apache.lucene.search.LeafCollector;
1515
import org.apache.lucene.search.Query;
1616
import org.apache.lucene.search.ScoreDoc;
17-
import org.apache.lucene.search.ScoreMode;
1817
import org.apache.lucene.search.Sort;
1918
import org.apache.lucene.search.SortField;
2019
import org.apache.lucene.search.TopDocsCollector;
2120
import org.apache.lucene.search.TopFieldCollectorManager;
2221
import org.apache.lucene.search.TopScoreDocCollectorManager;
22+
import org.apache.lucene.search.Weight;
2323
import org.elasticsearch.common.Strings;
2424
import org.elasticsearch.compute.data.BlockFactory;
2525
import org.elasticsearch.compute.data.DocBlock;
@@ -36,16 +36,14 @@
3636
import org.elasticsearch.search.sort.SortBuilder;
3737

3838
import java.io.IOException;
39+
import java.io.UncheckedIOException;
3940
import java.util.ArrayList;
4041
import java.util.Arrays;
4142
import java.util.List;
4243
import java.util.Optional;
4344
import java.util.function.Function;
4445
import java.util.stream.Collectors;
4546

46-
import static org.apache.lucene.search.ScoreMode.TOP_DOCS;
47-
import static org.apache.lucene.search.ScoreMode.TOP_DOCS_WITH_SCORES;
48-
4947
/**
5048
* Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
5149
*/
@@ -63,16 +61,16 @@ public Factory(
6361
int maxPageSize,
6462
int limit,
6563
List<SortBuilder<?>> sorts,
66-
boolean scoring
64+
boolean needsScore
6765
) {
68-
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? TOP_DOCS_WITH_SCORES : TOP_DOCS);
66+
super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
6967
this.maxPageSize = maxPageSize;
7068
this.sorts = sorts;
7169
}
7270

7371
@Override
7472
public SourceOperator get(DriverContext driverContext) {
75-
return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, scoreMode);
73+
return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore);
7674
}
7775

7876
public int maxPageSize() {
@@ -88,8 +86,8 @@ public String describe() {
8886
+ maxPageSize
8987
+ ", limit = "
9088
+ limit
91-
+ ", scoreMode = "
92-
+ scoreMode
89+
+ ", needsScore = "
90+
+ needsScore
9391
+ ", sorts = ["
9492
+ notPrettySorts
9593
+ "]]";
@@ -108,20 +106,20 @@ public String describe() {
108106
private PerShardCollector perShardCollector;
109107
private final List<SortBuilder<?>> sorts;
110108
private final int limit;
111-
private final ScoreMode scoreMode;
109+
private final boolean needsScore;
112110

113111
public LuceneTopNSourceOperator(
114112
BlockFactory blockFactory,
115113
int maxPageSize,
116114
List<SortBuilder<?>> sorts,
117115
int limit,
118116
LuceneSliceQueue sliceQueue,
119-
ScoreMode scoreMode
117+
boolean needsScore
120118
) {
121119
super(blockFactory, maxPageSize, sliceQueue);
122120
this.sorts = sorts;
123121
this.limit = limit;
124-
this.scoreMode = scoreMode;
122+
this.needsScore = needsScore;
125123
}
126124

127125
@Override
@@ -163,7 +161,7 @@ private Page collect() throws IOException {
163161
try {
164162
if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
165163
// TODO: share the bottom between shardCollectors
166-
perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
164+
perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, needsScore, limit);
167165
}
168166
var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
169167
scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
@@ -261,7 +259,7 @@ private float getScore(ScoreDoc scoreDoc) {
261259
}
262260

263261
private DoubleVector.Builder scoreVectorOrNull(int size) {
264-
if (scoreMode.needsScores()) {
262+
if (needsScore) {
265263
return blockFactory.newDoubleVectorFixedBuilder(size);
266264
} else {
267265
return null;
@@ -271,37 +269,11 @@ private DoubleVector.Builder scoreVectorOrNull(int size) {
271269
@Override
272270
protected void describe(StringBuilder sb) {
273271
sb.append(", limit = ").append(limit);
274-
sb.append(", scoreMode = ").append(scoreMode);
272+
sb.append(", needsScore = ").append(needsScore);
275273
String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
276274
sb.append(", sorts = [").append(notPrettySorts).append("]");
277275
}
278276

279-
PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
280-
Optional<SortAndFormats> sortAndFormats = shardContext.buildSort(sorts);
281-
if (sortAndFormats.isEmpty()) {
282-
throw new IllegalStateException("sorts must not be disabled in TopN");
283-
}
284-
if (scoreMode.needsScores() == false) {
285-
return new NonScoringPerShardCollector(shardContext, sortAndFormats.get().sort, limit);
286-
} else {
287-
SortField[] sortFields = sortAndFormats.get().sort.getSort();
288-
if (sortFields != null && sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false) {
289-
// SORT _score DESC
290-
return new ScoringPerShardCollector(shardContext, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
291-
} else {
292-
// SORT ..., _score, ...
293-
var sort = new Sort();
294-
if (sortFields != null) {
295-
var l = new ArrayList<>(Arrays.asList(sortFields));
296-
l.add(SortField.FIELD_DOC);
297-
l.add(SortField.FIELD_SCORE);
298-
sort = new Sort(l.toArray(SortField[]::new));
299-
}
300-
return new ScoringPerShardCollector(shardContext, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
301-
}
302-
}
303-
}
304-
305277
abstract static class PerShardCollector {
306278
private final ShardContext shardContext;
307279
private final TopDocsCollector<?> collector;
@@ -336,4 +308,45 @@ static final class ScoringPerShardCollector extends PerShardCollector {
336308
super(shardContext, topDocsCollector);
337309
}
338310
}
311+
312+
private static Function<ShardContext, Weight> weightFunction(
313+
Function<ShardContext, Query> queryFunction,
314+
List<SortBuilder<?>> sorts,
315+
boolean needsScore
316+
) {
317+
return ctx -> {
318+
final var query = queryFunction.apply(ctx);
319+
final var searcher = ctx.searcher();
320+
try {
321+
// we create a collector with a limit of 1 to determine the appropriate score mode to use.
322+
var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode();
323+
return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
324+
} catch (IOException e) {
325+
throw new UncheckedIOException(e);
326+
}
327+
};
328+
}
329+
330+
private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit)
331+
throws IOException {
332+
Optional<SortAndFormats> sortAndFormats = context.buildSort(sorts);
333+
if (sortAndFormats.isEmpty()) {
334+
throw new IllegalStateException("sorts must not be disabled in TopN");
335+
}
336+
if (needsScore == false) {
337+
return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
338+
}
339+
Sort sort = sortAndFormats.get().sort;
340+
if (Sort.RELEVANCE.equals(sort)) {
341+
// SORT _score DESC
342+
return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
343+
}
344+
345+
// SORT ..., _score, ...
346+
var l = new ArrayList<>(Arrays.asList(sort.getSort()));
347+
l.add(SortField.FIELD_DOC);
348+
l.add(SortField.FIELD_SCORE);
349+
sort = new Sort(l.toArray(SortField[]::new));
350+
return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
351+
}
339352
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSortedSourceOperatorFactory.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import java.util.List;
3434
import java.util.function.Function;
3535

36+
import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
37+
3638
/**
3739
* Creates a source operator that takes advantage of the natural sorting of segments in a tsdb index.
3840
* <p>
@@ -56,7 +58,7 @@ private TimeSeriesSortedSourceOperatorFactory(
5658
int maxPageSize,
5759
int limit
5860
) {
59-
super(contexts, queryFunction, DataPartitioning.SHARD, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
61+
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), DataPartitioning.SHARD, taskConcurrency, limit, false);
6062
this.maxPageSize = maxPageSize;
6163
}
6264

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ protected Matcher<String> expectedToStringOfSimple() {
120120
protected Matcher<String> expectedDescriptionOfSimple() {
121121
return matchesRegex(
122122
"LuceneSourceOperator"
123-
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = (COMPLETE|COMPLETE_NO_SCORES)]"
123+
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = (true|false)]"
124124
);
125125
}
126126

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorScoringTests.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,15 @@ public Optional<SortAndFormats> buildSort(List<SortBuilder<?>> sorts) {
110110
@Override
111111
protected Matcher<String> expectedToStringOfSimple() {
112112
return matchesRegex(
113-
"LuceneTopNSourceOperator\\[shards = \\[test], "
114-
+ "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
113+
"LuceneTopNSourceOperator\\[shards = \\[test], " + "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
115114
);
116115
}
117116

118117
@Override
119118
protected Matcher<String> expectedDescriptionOfSimple() {
120119
return matchesRegex(
121120
"LuceneTopNSourceOperator\\[dataPartitioning = (DOC|SHARD|SEGMENT), "
122-
+ "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
121+
+ "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
123122
);
124123
}
125124

0 commit comments

Comments
 (0)