Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125930.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125930
summary: Infer the score mode to use from the Lucene collector
area: "ES|QL"
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public Factory(
int taskConcurrency,
int limit
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;

/**
* Factory that generates an operator that finds the max value of a field using the {@link LuceneMinMaxOperator}.
*/
Expand Down Expand Up @@ -121,7 +123,7 @@ public LuceneMaxFactory(
NumberType numberType,
int limit
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
this.fieldName = fieldName;
this.numberType = numberType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;

/**
* Factory that generates an operator that finds the min value of a field using the {@link LuceneMinMaxOperator}.
*/
Expand Down Expand Up @@ -121,7 +123,7 @@ public LuceneMinFactory(
NumberType numberType,
int limit
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
this.fieldName = fieldName;
this.numberType = numberType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
Expand Down Expand Up @@ -83,28 +82,27 @@ public abstract static class Factory implements SourceOperator.SourceOperatorFac
protected final DataPartitioning dataPartitioning;
protected final int taskConcurrency;
protected final int limit;
protected final ScoreMode scoreMode;
protected final boolean needsScore;
protected final LuceneSliceQueue sliceQueue;

/**
* Build the factory.
*
* @param scoreMode the {@link ScoreMode} passed to {@link IndexSearcher#createWeight}
* @param needsScore Whether the score is needed.
*/
protected Factory(
List<? extends ShardContext> contexts,
Function<ShardContext, Query> queryFunction,
Function<ShardContext, Weight> weightFunction,
DataPartitioning dataPartitioning,
int taskConcurrency,
int limit,
ScoreMode scoreMode
boolean needsScore
) {
this.limit = limit;
this.scoreMode = scoreMode;
this.dataPartitioning = dataPartitioning;
var weightFunction = weightFunction(queryFunction, scoreMode);
this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
this.needsScore = needsScore;
}

public final int taskConcurrency() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
Expand Down Expand Up @@ -56,17 +55,24 @@ public Factory(
int taskConcurrency,
int maxPageSize,
int limit,
boolean scoring
boolean needsScore
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
super(
contexts,
weightFunction(queryFunction, needsScore ? COMPLETE : COMPLETE_NO_SCORES),
dataPartitioning,
taskConcurrency,
limit,
needsScore
);
this.maxPageSize = maxPageSize;
// TODO: use a single limiter for multiple stage execution
this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
}

@Override
public SourceOperator get(DriverContext driverContext) {
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore);
}

public int maxPageSize() {
Expand All @@ -81,8 +87,8 @@ public String describe() {
+ maxPageSize
+ ", limit = "
+ limit
+ ", scoreMode = "
+ scoreMode
+ ", needsScore = "
+ needsScore
+ "]";
}
}
Expand All @@ -94,7 +100,7 @@ public LuceneSourceOperator(
LuceneSliceQueue sliceQueue,
int limit,
Limiter limiter,
ScoreMode scoreMode
boolean needsScore
) {
super(blockFactory, maxPageSize, sliceQueue);
this.minPageSize = Math.max(1, maxPageSize / 2);
Expand All @@ -104,7 +110,7 @@ public LuceneSourceOperator(
boolean success = false;
try {
this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
if (scoreMode.needsScores()) {
if (needsScore) {
scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
this.leafCollector = new ScoringCollector();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollectorManager;
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.Strings;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
Expand All @@ -36,16 +36,14 @@
import org.elasticsearch.search.sort.SortBuilder;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.apache.lucene.search.ScoreMode.COMPLETE;
import static org.apache.lucene.search.ScoreMode.TOP_DOCS;

/**
* Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
*/
Expand All @@ -62,16 +60,16 @@ public Factory(
int maxPageSize,
int limit,
List<SortBuilder<?>> sorts,
boolean scoring
boolean needsScore
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : TOP_DOCS);
super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
this.maxPageSize = maxPageSize;
this.sorts = sorts;
}

@Override
public SourceOperator get(DriverContext driverContext) {
return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, scoreMode);
return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore);
}

public int maxPageSize() {
Expand All @@ -87,8 +85,8 @@ public String describe() {
+ maxPageSize
+ ", limit = "
+ limit
+ ", scoreMode = "
+ scoreMode
+ ", needsScore = "
+ needsScore
+ ", sorts = ["
+ notPrettySorts
+ "]]";
Expand All @@ -107,20 +105,20 @@ public String describe() {
private PerShardCollector perShardCollector;
private final List<SortBuilder<?>> sorts;
private final int limit;
private final ScoreMode scoreMode;
private final boolean needsScore;

public LuceneTopNSourceOperator(
BlockFactory blockFactory,
int maxPageSize,
List<SortBuilder<?>> sorts,
int limit,
LuceneSliceQueue sliceQueue,
ScoreMode scoreMode
boolean needsScore
) {
super(blockFactory, maxPageSize, sliceQueue);
this.sorts = sorts;
this.limit = limit;
this.scoreMode = scoreMode;
this.needsScore = needsScore;
}

@Override
Expand Down Expand Up @@ -162,7 +160,7 @@ private Page collect() throws IOException {
try {
if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
// TODO: share the bottom between shardCollectors
perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, needsScore, limit);
}
var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
Expand Down Expand Up @@ -260,7 +258,7 @@ private float getScore(ScoreDoc scoreDoc) {
}

private DoubleVector.Builder scoreVectorOrNull(int size) {
if (scoreMode.needsScores()) {
if (needsScore) {
return blockFactory.newDoubleVectorFixedBuilder(size);
} else {
return null;
Expand All @@ -270,43 +268,11 @@ private DoubleVector.Builder scoreVectorOrNull(int size) {
@Override
protected void describe(StringBuilder sb) {
sb.append(", limit = ").append(limit);
sb.append(", scoreMode = ").append(scoreMode);
sb.append(", needsScore = ").append(needsScore);
String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
sb.append(", sorts = [").append(notPrettySorts).append("]");
}

/**
 * Creates the collector used to gather the top {@code limit} documents of one shard.
 * <p>
 * When the operator's score mode does not need scores, a {@link NonScoringPerShardCollector} over the
 * requested sort is returned. Otherwise a {@link ScoringPerShardCollector} is returned, backed by a
 * score-ordered collector when the only sort key is a non-reversed score sort, or by a field collector
 * whose sort keys are extended with {@link SortField#FIELD_DOC} and {@link SortField#FIELD_SCORE}.
 *
 * @param shardContext the shard to collect from; also used to build the Lucene sort
 * @param sorts        the requested sorts
 * @param limit        the number of hits to keep
 * @return the collector for this shard
 * @throws IOException           if building the sort or the collector fails
 * @throws IllegalStateException if the shard context yields no sort for {@code sorts}
 */
PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
    final Sort requestedSort = shardContext.buildSort(sorts)
        .orElseThrow(() -> new IllegalStateException("sorts must not be disabled in TopN")).sort;

    if (scoreMode.needsScores() == false) {
        return new NonScoringPerShardCollector(shardContext, requestedSort, limit);
    }

    final SortField[] fields = requestedSort.getSort();
    final boolean scoreDescendingOnly = fields != null
        && fields.length == 1
        && fields[0].needsScores()
        && fields[0].getReverse() == false;
    if (scoreDescendingOnly) {
        // SORT _score DESC
        return new ScoringPerShardCollector(
            shardContext,
            new TopScoreDocCollectorManager(limit, null, limit, false).newCollector()
        );
    }

    // SORT ..., _score, ...
    var effectiveSort = new Sort();
    if (fields != null) {
        final List<SortField> extended = new ArrayList<>(Arrays.asList(fields));
        extended.add(SortField.FIELD_DOC);
        extended.add(SortField.FIELD_SCORE);
        effectiveSort = new Sort(extended.toArray(SortField[]::new));
    }
    return new ScoringPerShardCollector(
        shardContext,
        new TopFieldCollectorManager(effectiveSort, limit, null, limit, false).newCollector()
    );
}

abstract static class PerShardCollector {
private final ShardContext shardContext;
private final TopDocsCollector<?> collector;
Expand Down Expand Up @@ -341,4 +307,45 @@ static final class ScoringPerShardCollector extends PerShardCollector {
super(shardContext, topDocsCollector);
}
}

/**
 * Builds the per-shard function that turns the query into a Lucene {@link Weight}.
 * <p>
 * The score mode handed to {@code createWeight} is not hard-coded: it is read from a throwaway
 * per-shard collector built for the same sorts and {@code needsScore} flag.
 *
 * @param queryFunction supplies the query to run against a given shard
 * @param sorts         the requested sorts, used only to build the probe collector
 * @param needsScore    whether scores are needed
 * @return a function producing the weight for a shard; it throws {@link UncheckedIOException}
 *         when building the probe collector, rewriting the query, or creating the weight fails
 *         with an {@link IOException}
 */
private static Function<ShardContext, Weight> weightFunction(
    Function<ShardContext, Query> queryFunction,
    List<SortBuilder<?>> sorts,
    boolean needsScore
) {
    return shardContext -> {
        final Query shardQuery = queryFunction.apply(shardContext);
        final var shardSearcher = shardContext.searcher();
        try {
            // The probe collector is created with a limit of 1: only its score mode is of interest.
            final PerShardCollector probe = newPerShardCollector(shardContext, sorts, needsScore, 1);
            final var inferredScoreMode = probe.collector.scoreMode();
            final Query rewritten = shardSearcher.rewrite(shardQuery);
            return shardSearcher.createWeight(rewritten, inferredScoreMode, 1);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    };
}

/**
 * Creates the collector used to gather the top {@code limit} documents of one shard.
 * <p>
 * Without scores, a {@link NonScoringPerShardCollector} over the requested sort is returned. With
 * scores, a {@link ScoringPerShardCollector} is returned: score-ordered when the requested sort equals
 * {@link Sort#RELEVANCE}, otherwise field-ordered with {@link SortField#FIELD_DOC} and
 * {@link SortField#FIELD_SCORE} appended to the requested sort keys.
 *
 * @param context    the shard to collect from; also used to build the Lucene sort
 * @param sorts      the requested sorts
 * @param needsScore whether scores are needed
 * @param limit      the number of hits to keep
 * @return the collector for this shard
 * @throws IOException           if building the sort or the collector fails
 * @throws IllegalStateException if the shard context yields no sort for {@code sorts}
 */
private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit)
    throws IOException {
    final Sort requestedSort = context.buildSort(sorts)
        .orElseThrow(() -> new IllegalStateException("sorts must not be disabled in TopN")).sort;

    if (needsScore == false) {
        return new NonScoringPerShardCollector(context, requestedSort, limit);
    }

    if (Sort.RELEVANCE.equals(requestedSort)) {
        // SORT _score DESC
        return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
    }

    // SORT ..., _score, ...
    final List<SortField> extended = new ArrayList<>(Arrays.asList(requestedSort.getSort()));
    extended.add(SortField.FIELD_DOC);
    extended.add(SortField.FIELD_SCORE);
    final Sort scoringSort = new Sort(extended.toArray(SortField[]::new));
    return new ScoringPerShardCollector(context, new TopFieldCollectorManager(scoringSort, limit, null, 0).newCollector());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;

/**
* Creates a source operator that takes advantage of the natural sorting of segments in a tsdb index.
* <p>
Expand All @@ -56,7 +58,7 @@ private TimeSeriesSortedSourceOperatorFactory(
int maxPageSize,
int limit
) {
super(contexts, queryFunction, DataPartitioning.SHARD, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), DataPartitioning.SHARD, taskConcurrency, limit, false);
this.maxPageSize = maxPageSize;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ protected Matcher<String> expectedToStringOfSimple() {
protected Matcher<String> expectedDescriptionOfSimple() {
return matchesRegex(
"LuceneSourceOperator"
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = (COMPLETE|COMPLETE_NO_SCORES)]"
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = (true|false)]"
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@ public Optional<SortAndFormats> buildSort(List<SortBuilder<?>> sorts) {

@Override
protected Matcher<String> expectedToStringOfSimple() {
return matchesRegex("LuceneTopNSourceOperator\\[maxPageSize = \\d+, limit = 100, scoreMode = COMPLETE, sorts = \\[\\{.+}]]");
return matchesRegex("LuceneTopNSourceOperator\\[maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]");
}

@Override
protected Matcher<String> expectedDescriptionOfSimple() {
return matchesRegex(
"LuceneTopNSourceOperator"
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = COMPLETE, sorts = \\[\\{.+}]]"
+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
);
}

Expand Down
Loading