5 changes: 5 additions & 0 deletions docs/changelog/125930.yaml
@@ -0,0 +1,5 @@
pr: 125930
summary: Infer the score mode to use from the Lucene collector
area: "ES|QL"
type: enhancement
issues: []
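In short: rather than each factory hard-coding a ScoreMode for IndexSearcher#createWeight, the Weight is now created with whatever ScoreMode the Lucene collector that consumes the hits reports. A minimal sketch of the idea (illustrative helper, not code from this PR):

```java
import java.io.IOException;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Weight;

final class ScoreModeFromCollector {
    // Illustrative: derive the ScoreMode from the collector that will consume
    // the hits, then create the Weight with exactly that mode.
    static Weight weightFor(IndexSearcher searcher, Query query, Collector collector) throws IOException {
        return searcher.createWeight(searcher.rewrite(query), collector.scoreMode(), 1f);
    }
}
```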
@@ -49,7 +49,7 @@ public Factory(
int taskConcurrency,
int limit
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
}

@Override
LuceneMaxFactory.java
@@ -23,6 +23,8 @@
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;

/**
* Factory that generates an operator that finds the max value of a field using the {@link LuceneMinMaxOperator}.
*/
@@ -121,7 +123,7 @@ public LuceneMaxFactory(
NumberType numberType,
int limit
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
this.fieldName = fieldName;
this.numberType = numberType;
}
LuceneMinFactory.java
@@ -23,6 +23,8 @@
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;

/**
* Factory that generates an operator that finds the min value of a field using the {@link LuceneMinMaxOperator}.
*/
@@ -121,7 +123,7 @@ public LuceneMinFactory(
NumberType numberType,
int limit
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
this.fieldName = fieldName;
this.numberType = numberType;
}
LuceneOperator.java
@@ -11,7 +11,6 @@
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
@@ -84,28 +83,27 @@ public abstract static class Factory implements SourceOperator.SourceOperatorFac
protected final DataPartitioning dataPartitioning;
protected final int taskConcurrency;
protected final int limit;
protected final ScoreMode scoreMode;
protected final boolean needsScore;
protected final LuceneSliceQueue sliceQueue;

/**
* Build the factory.
*
* @param scoreMode the {@link ScoreMode} passed to {@link IndexSearcher#createWeight}
* @param needsScore Whether the score is needed.
*/
protected Factory(
List<? extends ShardContext> contexts,
Function<ShardContext, Query> queryFunction,
Function<ShardContext, Weight> weightFunction,
DataPartitioning dataPartitioning,
int taskConcurrency,
int limit,
ScoreMode scoreMode
boolean needsScore
) {
this.limit = limit;
this.scoreMode = scoreMode;
this.dataPartitioning = dataPartitioning;
var weightFunction = weightFunction(queryFunction, scoreMode);
this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
this.needsScore = needsScore;
}

public final int taskConcurrency() {
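The factories in this PR now statically import LuceneOperator.weightFunction, but its definition is not part of the visible hunks. Judging from the call sites, and from the analogous sort-aware helper added to LuceneTopNSourceOperator further down, it presumably looks like this (a sketch under that assumption):

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.Function;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

// Sketch (assumed; the helper's body is not shown in this diff): adapt a per-shard
// Query function into a per-shard Weight function for a fixed ScoreMode.
public static Function<ShardContext, Weight> weightFunction(Function<ShardContext, Query> queryFunction, ScoreMode scoreMode) {
    return ctx -> {
        final var query = queryFunction.apply(ctx);
        final var searcher = ctx.searcher();
        try {
            return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    };
}
```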
LuceneSourceOperator.java
@@ -11,7 +11,6 @@
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
@@ -56,17 +55,24 @@ public Factory(
int taskConcurrency,
int maxPageSize,
int limit,
boolean scoring
boolean needsScore
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
super(
contexts,
weightFunction(queryFunction, needsScore ? COMPLETE : COMPLETE_NO_SCORES),
dataPartitioning,
taskConcurrency,
limit,
needsScore
);
this.maxPageSize = maxPageSize;
// TODO: use a single limiter for multiple stage execution
this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
}

@Override
public SourceOperator get(DriverContext driverContext) {
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore);
}

public int maxPageSize() {
@@ -81,8 +87,8 @@ public String describe() {
+ maxPageSize
+ ", limit = "
+ limit
+ ", scoreMode = "
+ scoreMode
+ ", needsScore = "
+ needsScore
+ "]";
}
}
@@ -94,7 +100,7 @@ public LuceneSourceOperator(
LuceneSliceQueue sliceQueue,
int limit,
Limiter limiter,
ScoreMode scoreMode
boolean needsScore
) {
super(blockFactory, maxPageSize, sliceQueue);
this.minPageSize = Math.max(1, maxPageSize / 2);
@@ -104,7 +110,7 @@ public LuceneSourceOperator(
boolean success = false;
try {
this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
if (scoreMode.needsScores()) {
if (needsScore) {
scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
this.leafCollector = new ScoringCollector();
} else {
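When needsScore is set, the operator installs a ScoringCollector, whose body is outside this hunk. For orientation, a self-contained sketch of a scoring LeafCollector of that general shape (assumed; the real class appends to the operator's vector builders rather than plain lists):

```java
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;

// Minimal sketch (assumed shape, not the PR's class): a LeafCollector that
// records each hit's doc id together with the score reported by the Scorable.
final class SketchScoringCollector implements LeafCollector {
    final List<Integer> docs = new ArrayList<>();
    final List<Float> scores = new ArrayList<>();
    private Scorable scorable;

    @Override
    public void setScorer(Scorable scorable) {
        this.scorable = scorable;
    }

    @Override
    public void collect(int doc) throws IOException {
        docs.add(doc);
        scores.add(scorable.score());
    }
}
```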
LuceneTopNSourceOperator.java
@@ -14,12 +14,12 @@
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollectorManager;
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.Strings;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
@@ -36,16 +36,14 @@
import org.elasticsearch.search.sort.SortBuilder;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.apache.lucene.search.ScoreMode.TOP_DOCS;
import static org.apache.lucene.search.ScoreMode.TOP_DOCS_WITH_SCORES;

/**
* Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
*/
@@ -63,16 +61,16 @@ public Factory(
int maxPageSize,
int limit,
List<SortBuilder<?>> sorts,
boolean scoring
boolean needsScore
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? TOP_DOCS_WITH_SCORES : TOP_DOCS);
super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
this.maxPageSize = maxPageSize;
this.sorts = sorts;
}

@Override
public SourceOperator get(DriverContext driverContext) {
return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, scoreMode);
return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore);
}

public int maxPageSize() {
@@ -88,8 +86,8 @@ public String describe() {
+ maxPageSize
+ ", limit = "
+ limit
+ ", scoreMode = "
+ scoreMode
+ ", needsScore = "
+ needsScore
+ ", sorts = ["
+ notPrettySorts
+ "]]";
@@ -108,20 +106,20 @@ public String describe() {
private PerShardCollector perShardCollector;
private final List<SortBuilder<?>> sorts;
private final int limit;
private final ScoreMode scoreMode;
private final boolean needsScore;

public LuceneTopNSourceOperator(
BlockFactory blockFactory,
int maxPageSize,
List<SortBuilder<?>> sorts,
int limit,
LuceneSliceQueue sliceQueue,
ScoreMode scoreMode
boolean needsScore
) {
super(blockFactory, maxPageSize, sliceQueue);
this.sorts = sorts;
this.limit = limit;
this.scoreMode = scoreMode;
this.needsScore = needsScore;
}

@Override
@@ -163,7 +161,7 @@ private Page collect() throws IOException {
try {
if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
// TODO: share the bottom between shardCollectors
perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, needsScore, limit);
}
var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
@@ -261,7 +259,7 @@ private float getScore(ScoreDoc scoreDoc) {
}

private DoubleVector.Builder scoreVectorOrNull(int size) {
if (scoreMode.needsScores()) {
if (needsScore) {
return blockFactory.newDoubleVectorFixedBuilder(size);
} else {
return null;
@@ -271,37 +269,11 @@ private DoubleVector.Builder scoreVectorOrNull(int size) {
@Override
protected void describe(StringBuilder sb) {
sb.append(", limit = ").append(limit);
sb.append(", scoreMode = ").append(scoreMode);
sb.append(", needsScore = ").append(needsScore);
String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
sb.append(", sorts = [").append(notPrettySorts).append("]");
}

PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
Optional<SortAndFormats> sortAndFormats = shardContext.buildSort(sorts);
if (sortAndFormats.isEmpty()) {
throw new IllegalStateException("sorts must not be disabled in TopN");
}
if (scoreMode.needsScores() == false) {
return new NonScoringPerShardCollector(shardContext, sortAndFormats.get().sort, limit);
} else {
SortField[] sortFields = sortAndFormats.get().sort.getSort();
if (sortFields != null && sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false) {
// SORT _score DESC
return new ScoringPerShardCollector(shardContext, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
} else {
// SORT ..., _score, ...
var sort = new Sort();
if (sortFields != null) {
var l = new ArrayList<>(Arrays.asList(sortFields));
l.add(SortField.FIELD_DOC);
l.add(SortField.FIELD_SCORE);
sort = new Sort(l.toArray(SortField[]::new));
}
return new ScoringPerShardCollector(shardContext, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
}
}
}

abstract static class PerShardCollector {
private final ShardContext shardContext;
private final TopDocsCollector<?> collector;
@@ -336,4 +308,45 @@ static final class ScoringPerShardCollector extends PerShardCollector {
super(shardContext, topDocsCollector);
}
}

private static Function<ShardContext, Weight> weightFunction(
Function<ShardContext, Query> queryFunction,
List<SortBuilder<?>> sorts,
boolean needsScore
) {
return ctx -> {
final var query = queryFunction.apply(ctx);
final var searcher = ctx.searcher();
try {
// we create a collector with a limit of 1 to determine the appropriate score mode to use.
var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode();
return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}

private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit)
throws IOException {
Optional<SortAndFormats> sortAndFormats = context.buildSort(sorts);
if (sortAndFormats.isEmpty()) {
throw new IllegalStateException("sorts must not be disabled in TopN");
}
if (needsScore == false) {
return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
}
Sort sort = sortAndFormats.get().sort;
if (Sort.RELEVANCE.equals(sort)) {
// SORT _score DESC
return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
}

// SORT ..., _score, ...
var l = new ArrayList<>(Arrays.asList(sort.getSort()));
l.add(SortField.FIELD_DOC);
l.add(SortField.FIELD_SCORE);
sort = new Sort(l.toArray(SortField[]::new));
return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
}
}
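The weightFunction above builds a throwaway collector with a limit of 1 purely to ask it which ScoreMode it needs, so the Weight matches what collection will actually do. The snippet below (illustrative; the exact modes reported are Lucene implementation behavior, not spelled out in this diff) shows the kind of answers such probe collectors give:

```java
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldCollectorManager;
import org.apache.lucene.search.TopScoreDocCollectorManager;

class ScoreModeProbe {
    public static void main(String[] args) throws Exception {
        // Build throwaway limit-1 collectors and ask them for their ScoreMode,
        // which is exactly what weightFunction does before creating the Weight.
        var scoring = new TopScoreDocCollectorManager(1, null, 0).newCollector();
        var nonScoring = new TopFieldCollectorManager(new Sort(SortField.FIELD_DOC), 1, null, 0).newCollector();
        System.out.println(scoring.scoreMode());    // expected: TOP_SCORES
        System.out.println(nonScoring.scoreMode()); // expected: TOP_DOCS
    }
}
```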
TimeSeriesSortedSourceOperatorFactory.java
@@ -33,6 +33,8 @@
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;

/**
* Creates a source operator that takes advantage of the natural sorting of segments in a tsdb index.
* <p>
@@ -56,7 +58,7 @@ private TimeSeriesSortedSourceOperatorFactory(
int maxPageSize,
int limit
) {
super(contexts, queryFunction, DataPartitioning.SHARD, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), DataPartitioning.SHARD, taskConcurrency, limit, false);
this.maxPageSize = maxPageSize;
}

LuceneSourceOperatorTests.java
@@ -120,7 +120,7 @@ protected Matcher<String> expectedToStringOfSimple() {
protected Matcher<String> expectedDescriptionOfSimple() {
return matchesRegex(
"LuceneSourceOperator"
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = (COMPLETE|COMPLETE_NO_SCORES)]"
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = (true|false)]"
);
}

LuceneTopNSourceOperatorScoringTests.java
@@ -110,16 +110,15 @@ public Optional<SortAndFormats> buildSort(List<SortBuilder<?>> sorts) {
@Override
protected Matcher<String> expectedToStringOfSimple() {
return matchesRegex(
"LuceneTopNSourceOperator\\[shards = \\[test], "
+ "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
"LuceneTopNSourceOperator\\[shards = \\[test], " + "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
);
}

@Override
protected Matcher<String> expectedDescriptionOfSimple() {
return matchesRegex(
"LuceneTopNSourceOperator\\[dataPartitioning = (DOC|SHARD|SEGMENT), "
+ "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
+ "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
);
}
