diff --git a/docs/changelog/121322.yaml b/docs/changelog/121322.yaml new file mode 100644 index 0000000000000..8ef570994bbd9 --- /dev/null +++ b/docs/changelog/121322.yaml @@ -0,0 +1,5 @@ +pr: 121322 +summary: "[PoC 2] ESQL - Add scoring for full text functions disjunctions using `ExpressionEvaluator`" +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java index 0ba1872504c40..81be3ed7514f9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java @@ -23,12 +23,15 @@ import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.DocVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; import java.io.IOException; import java.io.UncheckedIOException; @@ -41,16 +44,20 @@ * this evaluator is here to save the day. */ public class LuceneQueryExpressionEvaluator implements EvalOperator.ExpressionEvaluator { + public record ShardConfig(Query query, IndexSearcher searcher) {} private final BlockFactory blockFactory; private final ShardConfig[] shards; + private final boolean usesScoring; private ShardState[] perShardState = EMPTY_SHARD_STATES; + private DoubleVector scoreVector; - public LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) { + public LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards, boolean usesScoring) { this.blockFactory = blockFactory; this.shards = shards; + this.usesScoring = usesScoring; } @Override @@ -60,16 +67,26 @@ public Block eval(Page page) { assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input"; DocVector docs = (DocVector) block.asVector(); try { + final Tuple evalResult; if (docs.singleSegmentNonDecreasing()) { - return evalSingleSegmentNonDecreasing(docs).asBlock(); + evalResult = evalSingleSegmentNonDecreasing(docs); } else { - return evalSlow(docs).asBlock(); + evalResult = evalSlow(docs); } + // Cache the score vector for later use + scoreVector = evalResult.v2(); + return evalResult.v1().asBlock(); } catch (IOException e) { throw new UncheckedIOException(e); } } + @Override + public DoubleBlock score(Page page, BlockFactory blockFactory) { + assert scoreVector != null : "eval() should be invoked before calling score()"; + return scoreVector.asBlock(); + } + /** * Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents. *

@@ -100,7 +117,7 @@ public Block eval(Page page) { * common. *

*/ - private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException { + private Tuple evalSingleSegmentNonDecreasing(DocVector docs) throws IOException { ShardState shardState = shardState(docs.shards().getInt(0)); SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0)); int min = docs.docs().getInt(0); @@ -126,13 +143,16 @@ private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOEx * the order that the {@link DocVector} came in. *

*/ - private BooleanVector evalSlow(DocVector docs) throws IOException { + private Tuple evalSlow(DocVector docs) throws IOException { int[] map = docs.shardSegmentDocMapForwards(); // Clear any state flags from the previous run int prevShard = -1; int prevSegment = -1; SegmentState segmentState = null; - try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) { + try ( + BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount()); + DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount()) + ) { for (int i = 0; i < docs.getPositionCount(); i++) { int shard = docs.shards().getInt(docs.shards().getInt(map[i])); int segment = docs.segments().getInt(map[i]); @@ -144,19 +164,26 @@ private BooleanVector evalSlow(DocVector docs) throws IOException { } if (segmentState.noMatch) { builder.appendBoolean(false); + scoreBuilder.appendDouble(NO_MATCH_SCORE); } else { - segmentState.scoreSingleDocWithScorer(builder, docs.docs().getInt(map[i])); + segmentState.scoreSingleDocWithScorer(builder, scoreBuilder, docs.docs().getInt(map[i])); } } - try (BooleanVector outOfOrder = builder.build()) { - return outOfOrder.filter(docs.shardSegmentDocMapBackwards()); + try (BooleanVector outOfOrder = builder.build(); DoubleVector outOfOrderScores = scoreBuilder.build()) { + return Tuple.tuple( + outOfOrder.filter(docs.shardSegmentDocMapBackwards()), + outOfOrderScores.filter(docs.shardSegmentDocMapBackwards()) + ); } } } @Override public void close() { - + if ((scoreVector != null) && scoreVector.isReleased() == false) { + // Scores may not be retrieved calling score(), for example with NOT. Try to free up the score vector + Releasables.closeExpectNoException(scoreVector); + } } private ShardState shardState(int shard) throws IOException { @@ -175,7 +202,9 @@ private class ShardState { private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES; ShardState(ShardConfig config) throws IOException { - weight = config.searcher.createWeight(config.query, ScoreMode.COMPLETE_NO_SCORES, 0.0f); + float boost = usesScoring ? 1.0f : 0.0f; + ScoreMode scoreMode = usesScoring ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; + weight = config.searcher.createWeight(config.query, scoreMode, boost); searcher = config.searcher; } @@ -231,10 +260,13 @@ private SegmentState(Weight weight, LeafReaderContext ctx) { * Score a range using the {@link BulkScorer}. This should be faster * than using {@link #scoreSparse} for dense doc ids. */ - BooleanVector scoreDense(int min, int max) throws IOException { + Tuple scoreDense(int min, int max) throws IOException { int length = max - min + 1; if (noMatch) { - return blockFactory.newConstantBooleanVector(false, length); + return Tuple.tuple( + blockFactory.newConstantBooleanVector(false, length), + blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length) + ); } if (bulkScorer == null || // The bulkScorer wasn't initialized Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread @@ -243,12 +275,22 @@ BooleanVector scoreDense(int min, int max) throws IOException { bulkScorer = weight.bulkScorer(ctx); if (bulkScorer == null) { noMatch = true; - return blockFactory.newConstantBooleanVector(false, length); + return Tuple.tuple( + blockFactory.newConstantBooleanVector(false, length), + blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length) + ); } } - try (DenseCollector collector = new DenseCollector(blockFactory, min, max)) { + + final DenseCollector collector; + if (usesScoring) { + collector = new ScoringDenseCollector(blockFactory, min, max); + } else { + collector = new DenseCollector(blockFactory, min, max); + } + try (collector) { bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1); - return collector.build(); + return Tuple.tuple(collector.buildMatchVector(), collector.buildScoreVector()); } } @@ -256,16 +298,22 @@ BooleanVector scoreDense(int min, int max) throws IOException { * Score a vector of doc ids using {@link Scorer}. If you have a dense range of * doc ids it'd be faster to use {@link #scoreDense}. */ - BooleanVector scoreSparse(IntVector docs) throws IOException { + Tuple scoreSparse(IntVector docs) throws IOException { initScorer(docs.getInt(0)); if (noMatch) { - return blockFactory.newConstantBooleanVector(false, docs.getPositionCount()); + return Tuple.tuple( + blockFactory.newConstantBooleanVector(false, docs.getPositionCount()), + blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, docs.getPositionCount()) + ); } - try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) { + try ( + BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount()); + DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount()) + ) { for (int i = 0; i < docs.getPositionCount(); i++) { - scoreSingleDocWithScorer(builder, docs.getInt(i)); + scoreSingleDocWithScorer(builder, scoreBuilder, docs.getInt(i)); } - return builder.build(); + return Tuple.tuple(builder.build(), scoreBuilder.build()); } } @@ -285,7 +333,8 @@ private void initScorer(int minDocId) throws IOException { } } - private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) throws IOException { + private void scoreSingleDocWithScorer(BooleanVector.Builder builder, DoubleVector.Builder scoreBuilder, int doc) + throws IOException { if (scorer.iterator().docID() == doc) { builder.appendBoolean(true); } else if (scorer.iterator().docID() > doc) { @@ -305,13 +354,13 @@ private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) th * which isn't documented, but @jpountz swears is true. */ static class DenseCollector implements LeafCollector, Releasable { - private final BooleanVector.FixedBuilder builder; + private final BooleanVector.FixedBuilder matchBuilder; private final int max; int next; DenseCollector(BlockFactory blockFactory, int min, int max) { - this.builder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1); + this.matchBuilder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1); this.max = max; next = min; } @@ -320,40 +369,92 @@ static class DenseCollector implements LeafCollector, Releasable { public void setScorer(Scorable scorable) {} @Override - public void collect(int doc) { + public void collect(int doc) throws IOException { while (next++ < doc) { - builder.appendBoolean(false); + appendNoMatch(); } - builder.appendBoolean(true); + appendMatch(); + } + + protected void appendMatch() throws IOException { + matchBuilder.appendBoolean(true); + } + + protected void appendNoMatch() { + matchBuilder.appendBoolean(false); } - public BooleanVector build() { - return builder.build(); + public BooleanVector buildMatchVector() { + return matchBuilder.build(); + } + + public DoubleVector buildScoreVector() { + return null; } @Override public void finish() { while (next++ <= max) { - builder.appendBoolean(false); + appendNoMatch(); } } @Override public void close() { - Releasables.closeExpectNoException(builder); + Releasables.closeExpectNoException(matchBuilder); + } + } + + static class ScoringDenseCollector extends DenseCollector { + private final DoubleVector.FixedBuilder scoreBuilder; + private Scorable scorable; + + ScoringDenseCollector(BlockFactory blockFactory, int min, int max) { + super(blockFactory, min, max); + this.scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(max - min + 1); + } + + @Override + public void setScorer(Scorable scorable) { + this.scorable = scorable; + } + + @Override + protected void appendMatch() throws IOException { + super.appendMatch(); + scoreBuilder.appendDouble(scorable.score()); + } + + @Override + protected void appendNoMatch() { + super.appendNoMatch(); + scoreBuilder.appendDouble(NO_MATCH_SCORE); + } + + @Override + public DoubleVector buildScoreVector() { + return scoreBuilder.build(); + } + + @Override + public void close() { + super.close(); + Releasables.closeExpectNoException(scoreBuilder); } } public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { private final ShardConfig[] shardConfigs; + private final boolean usesScoring; - public Factory(ShardConfig[] shardConfigs) { + public Factory(ShardConfig[] shardConfigs, boolean usesScoring) { this.shardConfigs = shardConfigs; + this.usesScoring = usesScoring; } @Override public EvalOperator.ExpressionEvaluator get(DriverContext context) { - return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs); + return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs, usesScoring); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java index 2573baf78b16a..d26822c14c792 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java @@ -9,6 +9,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -60,6 +61,12 @@ public void close() { * Evaluates an expression {@code a + b} or {@code log(c)} one {@link Page} at a time. */ public interface ExpressionEvaluator extends Releasable { + + /** + * Scoring value when an expression does not match (return false) or does not contribute to the score + */ + double NO_MATCH_SCORE = 0.0; + /** A Factory for creating ExpressionEvaluators. */ interface Factory { ExpressionEvaluator get(DriverContext context); @@ -81,6 +88,14 @@ default boolean eagerEvalSafeInLazy() { * @return the returned Block has its own reference and the caller is responsible for releasing it. */ Block eval(Page page); + + /** + * Retrieves the score for the expression. + * Only expressions that can be scored, or that can combine scores, should override this method. + */ + default DoubleBlock score(Page page, BlockFactory blockFactory) { + return blockFactory.newConstantDoubleBlockWith(NO_MATCH_SCORE, page.getPositionCount()); + } } public static final ExpressionEvaluator.Factory CONSTANT_NULL_FACTORY = new ExpressionEvaluator.Factory() { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java index 5b8d485c4da3a..f515b0051f222 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java @@ -8,7 +8,10 @@ package org.elasticsearch.compute.operator; import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.core.Releasables; @@ -17,13 +20,17 @@ public class FilterOperator extends AbstractPageMappingOperator { + public static final int SCORE_BLOCK_INDEX = 1; + private final EvalOperator.ExpressionEvaluator evaluator; + private final boolean usesScoring; + private final BlockFactory blockFactory; - public record FilterOperatorFactory(ExpressionEvaluator.Factory evaluatorSupplier) implements OperatorFactory { + public record FilterOperatorFactory(ExpressionEvaluator.Factory evaluatorSupplier, boolean usesScoring) implements OperatorFactory { @Override public Operator get(DriverContext driverContext) { - return new FilterOperator(evaluatorSupplier.get(driverContext)); + return new FilterOperator(evaluatorSupplier.get(driverContext), usesScoring, driverContext.blockFactory()); } @Override @@ -32,8 +39,10 @@ public String describe() { } } - public FilterOperator(EvalOperator.ExpressionEvaluator evaluator) { + public FilterOperator(ExpressionEvaluator evaluator, boolean usesScoring, BlockFactory blockFactory) { this.evaluator = evaluator; + this.usesScoring = usesScoring; + this.blockFactory = blockFactory; } @Override @@ -41,21 +50,22 @@ protected Page process(Page page) { int rowCount = 0; int[] positions = new int[page.getPositionCount()]; - try (BooleanBlock test = (BooleanBlock) evaluator.eval(page)) { - if (test.areAllValuesNull()) { + try (BooleanBlock filterResultBlock = (BooleanBlock) evaluator.eval(page)) { + if (filterResultBlock.areAllValuesNull()) { // All results are null which is like false. No values selected. page.releaseBlocks(); return null; } + // TODO we can detect constant true or false from the type // TODO or we could make a new method in bool-valued evaluators that returns a list of numbers for (int p = 0; p < page.getPositionCount(); p++) { - if (test.isNull(p) || test.getValueCount(p) != 1) { + if (filterResultBlock.isNull(p) || filterResultBlock.getValueCount(p) != 1) { // Null is like false // And, for now, multivalued results are like false too continue; } - if (test.getBoolean(test.getFirstValueIndex(p))) { + if (filterResultBlock.getBoolean(filterResultBlock.getFirstValueIndex(p))) { positions[rowCount++] = p; } } @@ -64,7 +74,13 @@ protected Page process(Page page) { page.releaseBlocks(); return null; } - if (rowCount == page.getPositionCount()) { + DoubleBlock scoreBlock = null; + if (usesScoring) { + scoreBlock = evaluator.score(page, blockFactory); + assert scoreBlock != null : "score block is when using scoring"; + } + + if (rowCount == page.getPositionCount() && usesScoring == false) { return page; } positions = Arrays.copyOf(positions, rowCount); @@ -73,11 +89,16 @@ protected Page process(Page page) { boolean success = false; try { for (int i = 0; i < page.getBlockCount(); i++) { - filteredBlocks[i] = page.getBlock(i).filter(positions); + if (usesScoring && i == SCORE_BLOCK_INDEX) { + filteredBlocks[i] = createScoresBlock(rowCount, page.getBlock(i), scoreBlock, positions); + } else { + filteredBlocks[i] = page.getBlock(i).filter(positions); + } } success = true; } finally { page.releaseBlocks(); + Releasables.closeExpectNoException(scoreBlock); if (success == false) { Releasables.closeExpectNoException(filteredBlocks); } @@ -86,6 +107,15 @@ protected Page process(Page page) { } } + private Block createScoresBlock(int rowCount, DoubleBlock originalScoreBlock, DoubleBlock newScoreBlock, int[] positions) { + // Create a new scores block with the retrieved scores, that will replace the existing one on the result page + DoubleVector.Builder updatedScoresBuilder = blockFactory.newDoubleVectorBuilder(rowCount); + for (int j = 0; j < rowCount; j++) { + updatedScoresBuilder.appendDouble(originalScoreBlock.getDouble(positions[j]) + newScoreBlock.getDouble(positions[j])); + } + return updatedScoresBuilder.build().asBlock(); + } + @Override public String toString() { return "FilterOperator[" + "evaluator=" + evaluator + ']'; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java index 54b33732aa425..95801ba6d8d93 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java @@ -57,13 +57,13 @@ public class LuceneQueryExpressionEvaluatorTests extends ComputeTestCase { private static final String FIELD = "g"; - public void testDenseCollectorSmall() { + public void testDenseCollectorSmall() throws IOException { try (DenseCollector collector = new DenseCollector(blockFactory(), 0, 2)) { collector.collect(0); collector.collect(1); collector.collect(2); collector.finish(); - try (BooleanVector result = collector.build()) { + try (BooleanVector result = collector.buildMatchVector()) { for (int i = 0; i <= 2; i++) { assertThat(result.getBoolean(i), equalTo(true)); } @@ -71,12 +71,12 @@ public void testDenseCollectorSmall() { } } - public void testDenseCollectorSimple() { + public void testDenseCollectorSimple() throws IOException { try (DenseCollector collector = new DenseCollector(blockFactory(), 0, 10)) { collector.collect(2); collector.collect(5); collector.finish(); - try (BooleanVector result = collector.build()) { + try (BooleanVector result = collector.buildMatchVector()) { for (int i = 0; i < 11; i++) { assertThat(result.getBoolean(i), equalTo(i == 2 || i == 5)); } @@ -84,7 +84,7 @@ public void testDenseCollectorSimple() { } } - public void testDenseCollector() { + public void testDenseCollector() throws IOException { int length = between(1, 10_000); int min = between(0, Integer.MAX_VALUE - length - 1); int max = min + length + 1; @@ -97,7 +97,7 @@ public void testDenseCollector() { } } collector.finish(); - try (BooleanVector result = collector.build()) { + try (BooleanVector result = collector.buildMatchVector()) { for (int i = 0; i < length; i++) { assertThat(result.getBoolean(i), equalTo(expected[i])); } @@ -132,6 +132,7 @@ private void assertTermQuery(String term, List results) { matchCount++; } } + page.releaseBlocks(); } assertThat(matchCount, equalTo(1)); } @@ -183,8 +184,8 @@ private List runQuery(Set values, Query query, boolean shuffleDocs ); LuceneQueryExpressionEvaluator luceneQueryEvaluator = new LuceneQueryExpressionEvaluator( blockFactory, - new LuceneQueryExpressionEvaluator.ShardConfig[] { shard } - + new LuceneQueryExpressionEvaluator.ShardConfig[] { shard }, + false ); List operators = new ArrayList<>(); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java index ce85d9baa5c7d..b6c29d083b978 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java @@ -66,7 +66,7 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) { public String toString() { return "SameLastDigit[lhs=0, rhs=1]"; } - }); + }, false); } @Override @@ -113,7 +113,7 @@ public void testReadFromBlock() { new SequenceBooleanBlockSourceOperator(context.blockFactory(), List.of(true, false, true, false)) ); List results = drive( - new FilterOperator.FilterOperatorFactory(dvrCtx -> new EvalOperatorTests.LoadFromPage(0)).get(context), + new FilterOperator.FilterOperatorFactory(dvrCtx -> new EvalOperatorTests.LoadFromPage(0), false).get(context), input.iterator(), context ); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/MatchFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/MatchFunctionIT.java index b928b25929401..6f569774359f1 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/MatchFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/MatchFunctionIT.java @@ -19,7 +19,11 @@ import java.util.List; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; //@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug") public class MatchFunctionIT extends AbstractEsqlIntegTestCase { @@ -260,20 +264,123 @@ public void testMatchWithinEval() { assertThat(error.getMessage(), containsString("[MATCH] function is only supported in WHERE commands")); } + public void testDisjunctionScoring() { + var query = """ + FROM test METADATA _score + | WHERE match(content, "fox") OR length(content) < 20 + | KEEP id, _score + | SORT _score DESC, id ASC + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + List> values = getValuesList(resp); + assertThat(values.size(), equalTo(3)); + + assertThat(values.get(0).get(0), equalTo(1)); + assertThat(values.get(1).get(0), equalTo(6)); + assertThat(values.get(2).get(0), equalTo(2)); + + // Matches full text query and non pushable query + assertThat((Double) values.get(0).get(1), greaterThan(1.0)); + assertThat((Double) values.get(1).get(1), greaterThan(1.0)); + // Matches just non pushable query + assertThat((Double) values.get(2).get(1), equalTo(1.0)); + } + } + + public void testDisjunctionScoringMultipleNonPushableFunctions() { + var query = """ + FROM test METADATA _score + | WHERE match(content, "fox") OR length(content) < 20 AND id > 2 + | KEEP id, _score + | SORT _score DESC + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + List> values = getValuesList(resp); + assertThat(values.size(), equalTo(2)); + + assertThat(values.get(0).get(0), equalTo(1)); + assertThat(values.get(1).get(0), equalTo(6)); + + // Matches the full text query and a two pushable query + assertThat((Double) values.get(0).get(1), greaterThan(2.0)); + assertThat((Double) values.get(0).get(1), lessThan(3.0)); + // Matches just the match function + assertThat((Double) values.get(1).get(1), lessThan(2.0)); + assertThat((Double) values.get(1).get(1), greaterThan(1.0)); + } + } + + public void testDisjunctionScoringWithNot() { + var query = """ + FROM test METADATA _score + | WHERE NOT(match(content, "dog")) OR length(content) > 50 + | KEEP id, _score + | SORT _score DESC, id ASC + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + List> values = getValuesList(resp); + assertThat(values.size(), equalTo(3)); + + assertThat(values.get(0).get(0), equalTo(1)); + assertThat(values.get(1).get(0), equalTo(4)); + assertThat(values.get(2).get(0), equalTo(5)); + + // Matches NOT gets 0.0 and default score is 1.0 + assertThat((Double) values.get(0).get(1), equalTo(1.0)); + assertThat((Double) values.get(1).get(1), equalTo(1.0)); + assertThat((Double) values.get(2).get(1), equalTo(1.0)); + } + } + + public void testScoringWithNoFullTextFunction() { + var query = """ + FROM test METADATA _score + | WHERE length(content) > 50 + | KEEP id, _score + | SORT _score DESC, id ASC + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + List> values = getValuesList(resp); + assertThat(values.size(), equalTo(1)); + + assertThat(values.get(0).get(0), equalTo(4)); + + // Non pushable query gets score of 0.0, summed with 1.0 coming from Lucene + assertThat((Double) values.get(0).get(1), equalTo(1.0)); + } + } + private void createAndPopulateIndex() { var indexName = "test"; var client = client().admin().indices(); var CreateRequest = client.prepareCreate(indexName) .setSettings(Settings.builder().put("index.number_of_shards", 1)) - .setMapping("id", "type=integer", "content", "type=text"); + .setMapping("id", "type=integer", "content", "type=text", "length", "type=integer"); assertAcked(CreateRequest); client().prepareBulk() - .add(new IndexRequest(indexName).id("1").source("id", 1, "content", "This is a brown fox")) - .add(new IndexRequest(indexName).id("2").source("id", 2, "content", "This is a brown dog")) - .add(new IndexRequest(indexName).id("3").source("id", 3, "content", "This dog is really brown")) - .add(new IndexRequest(indexName).id("4").source("id", 4, "content", "The dog is brown but this document is very very long")) - .add(new IndexRequest(indexName).id("5").source("id", 5, "content", "There is also a white cat")) - .add(new IndexRequest(indexName).id("6").source("id", 6, "content", "The quick brown fox jumps over the lazy dog")) + .add(new IndexRequest(indexName).id("1").source("id", 1, "content", "This is a brown fox", "length", 19)) + .add(new IndexRequest(indexName).id("2").source("id", 2, "content", "This is a brown dog", "length", 19)) + .add(new IndexRequest(indexName).id("3").source("id", 3, "content", "This dog is really brown", "length", 25)) + .add( + new IndexRequest(indexName).id("4") + .source("id", 4, "content", "The dog is brown but this document is very very long", "length", 52) + ) + .add(new IndexRequest(indexName).id("5").source("id", 5, "content", "There is also a white cat", "length", 25)) + .add( + new IndexRequest(indexName).id("6").source("id", 6, "content", "The quick brown fox jumps over the lazy dog", "length", 43) + ) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .get(); ensureYellow(indexName); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java index de3b070adbb1f..c6aa736d2922d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java @@ -13,6 +13,7 @@ import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Vector; @@ -52,7 +53,7 @@ public final class EvalMapper { private EvalMapper() {} public static ExpressionEvaluator.Factory toEvaluator(FoldContext foldCtx, Expression exp, Layout layout) { - return toEvaluator(foldCtx, exp, layout, List.of()); + return toEvaluator(foldCtx, exp, layout, List.of(), false); } @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -68,13 +69,14 @@ public static ExpressionEvaluator.Factory toEvaluator( FoldContext foldCtx, Expression exp, Layout layout, - List shardContexts + List shardContexts, + boolean usesScoring ) { if (exp instanceof EvaluatorMapper m) { return m.toEvaluator(new EvaluatorMapper.ToEvaluator() { @Override public ExpressionEvaluator.Factory apply(Expression expression) { - return toEvaluator(foldCtx, expression, layout, shardContexts); + return toEvaluator(foldCtx, expression, layout, shardContexts, usesScoring); } @Override @@ -86,11 +88,16 @@ public FoldContext foldCtx() { public List shardContexts() { return shardContexts; } + + @Override + public boolean usesScoring() { + return usesScoring; + } }); } for (ExpressionMapper em : MAPPERS) { if (em.typeToken.isInstance(exp)) { - return em.map(foldCtx, exp, layout, shardContexts); + return em.map(foldCtx, exp, layout, shardContexts, usesScoring); } } throw new QlIllegalArgumentException("Unsupported expression [{}]", exp); @@ -98,9 +105,15 @@ public List shardContexts() { static class BooleanLogic extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(FoldContext foldCtx, BinaryLogic bc, Layout layout, List shardContexts) { - var leftEval = toEvaluator(foldCtx, bc.left(), layout, shardContexts); - var rightEval = toEvaluator(foldCtx, bc.right(), layout, shardContexts); + public ExpressionEvaluator.Factory map( + FoldContext foldCtx, + BinaryLogic bc, + Layout layout, + List shardContexts, + boolean usesScoring + ) { + var leftEval = toEvaluator(foldCtx, bc.left(), layout, shardContexts, usesScoring); + var rightEval = toEvaluator(foldCtx, bc.right(), layout, shardContexts, usesScoring); /** * Evaluator for the three-valued boolean expressions. * We can't generate these with the {@link Evaluator} annotation because that @@ -165,6 +178,22 @@ private Block eval(BooleanVector lhs, BooleanVector rhs) { } } + @Override + public DoubleBlock score(Page page, BlockFactory blockFactory) { + try (DoubleBlock lhs = leftEval.score(page, blockFactory); DoubleBlock rhs = rightEval.score(page, blockFactory)) { + int positionCount = lhs.getPositionCount(); + // TODO We could optimize for constant vectors + try (var result = lhs.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + double l = lhs.getDouble(p); + double r = rhs.getDouble(p); + result.appendDouble(bl.function().score(l, r)); + } + return result.build().asBlock(); + } + } + } + @Override public void close() { Releasables.closeExpectNoException(leftEval, rightEval); @@ -176,8 +205,14 @@ public void close() { static class Nots extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(FoldContext foldCtx, Not not, Layout layout, List shardContexts) { - var expEval = toEvaluator(foldCtx, not.field(), layout); + public ExpressionEvaluator.Factory map( + FoldContext foldCtx, + Not not, + Layout layout, + List shardContexts, + boolean usesScoring + ) { + var expEval = toEvaluator(foldCtx, not.field(), layout, shardContexts, usesScoring); return dvrCtx -> new org.elasticsearch.xpack.esql.evaluator.predicate.operator.logical.NotEvaluator( not.source(), expEval.get(dvrCtx), @@ -188,7 +223,13 @@ public ExpressionEvaluator.Factory map(FoldContext foldCtx, Not not, Layout layo static class Attributes extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(FoldContext foldCtx, Attribute attr, Layout layout, List shardContexts) { + public ExpressionEvaluator.Factory map( + FoldContext foldCtx, + Attribute attr, + Layout layout, + List shardContexts, + boolean usesScoring + ) { record Attribute(int channel) implements ExpressionEvaluator { @Override public Block eval(Page page) { @@ -223,7 +264,13 @@ public boolean eagerEvalSafeInLazy() { static class Literals extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(FoldContext foldCtx, Literal lit, Layout layout, List shardContexts) { + public ExpressionEvaluator.Factory map( + FoldContext foldCtx, + Literal lit, + Layout layout, + List shardContexts, + boolean usesScoring + ) { record LiteralsEvaluator(DriverContext context, Literal lit) implements ExpressionEvaluator { @Override public Block eval(Page page) { @@ -280,8 +327,14 @@ private static Block block(Literal lit, BlockFactory blockFactory, int positions static class IsNulls extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNull isNull, Layout layout, List shardContexts) { - var field = toEvaluator(foldCtx, isNull.field(), layout); + public ExpressionEvaluator.Factory map( + FoldContext foldCtx, + IsNull isNull, + Layout layout, + List shardContexts, + boolean usesScoring + ) { + var field = toEvaluator(foldCtx, isNull.field(), layout, shardContexts, usesScoring); return new IsNullEvaluatorFactory(field); } @@ -328,8 +381,14 @@ public String toString() { static class IsNotNulls extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNotNull isNotNull, Layout layout, List shardContexts) { - return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout)); + public ExpressionEvaluator.Factory map( + FoldContext foldCtx, + IsNotNull isNotNull, + Layout layout, + List shardContexts, + boolean usesScoring + ) { + return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout, shardContexts, usesScoring)); } record IsNotNullEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory field) implements ExpressionEvaluator.Factory { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java index a4a17297abc09..57ca45f505d61 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java @@ -41,6 +41,10 @@ interface ToEvaluator { default List shardContexts() { throw new UnsupportedOperationException("Shard contexts should only be needed for evaluation operations"); } + + default boolean usesScoring() { + return false; + } } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java index 06a8a92ecfce8..9d29abb002715 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java @@ -23,5 +23,11 @@ public ExpressionMapper() { typeToken = ReflectionUtils.detectSuperTypeForRuleLike(getClass()); } - public abstract ExpressionEvaluator.Factory map(FoldContext foldCtx, E expression, Layout layout, List shardContexts); + public abstract ExpressionEvaluator.Factory map( + FoldContext foldCtx, + E expression, + Layout layout, + List shardContexts, + boolean usesScoring + ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index 32a350ac7351e..45cb6697a70e8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -17,18 +17,15 @@ import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; -import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; @@ -147,7 +144,7 @@ public boolean equals(Object obj) { @Override public boolean translatable(LucenePushdownPredicates pushdownPredicates) { - // In isolation, full text functions are pushable to source. We check if there are no disjunctions in Or conditions + // In isolation, full text functions are pushable to source return true; } @@ -208,13 +205,6 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failu failures ); checkFullTextFunctionsParents(condition, failures); - - boolean usesScore = plan.output() - .stream() - .anyMatch(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE)); - if (usesScore) { - checkFullTextSearchDisjunctions(condition, failures); - } } else { plan.forEachExpression(FullTextFunction.class, ftf -> { failures.add(fail(ftf, "[{}] {} is only supported in WHERE commands", ftf.functionName(), ftf.functionType())); @@ -222,65 +212,6 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failu } } - /** - * Checks whether a condition contains a disjunction with a full text search. - * If it does, check that every element of the disjunction is a full text search or combinations (AND, OR, NOT) of them. - * If not, add a failure to the failures collection. - * - * @param condition condition to check for disjunctions of full text searches - * @param failures failures collection to add to - */ - private static void checkFullTextSearchDisjunctions(Expression condition, Failures failures) { - Holder isInvalid = new Holder<>(false); - condition.forEachDown(Or.class, or -> { - if (isInvalid.get()) { - // Exit early if we already have a failures - return; - } - if (checkDisjunctionPushable(or) == false) { - isInvalid.set(true); - failures.add( - fail( - or, - "Invalid condition when using METADATA _score [{}]. Full text functions can be used in an OR condition, " - + "but only if just full text functions are used in the OR condition", - or.sourceText() - ) - ); - } - }); - } - - /** - * Checks if a disjunction is pushable from the point of view of FullTextFunctions. Either it has no FullTextFunctions or - * all it contains are FullTextFunctions. - * - * @param or disjunction to check - * @return true if the disjunction is pushable, false otherwise - */ - private static boolean checkDisjunctionPushable(Or or) { - boolean hasFullText = or.anyMatch(FullTextFunction.class::isInstance); - return hasFullText == false || onlyFullTextFunctionsInExpression(or); - } - - /** - * Checks whether an expression contains just full text functions or negations (NOT) and combinations (AND, OR) of full text functions - * - * @param expression expression to check - * @return true if all children are full text functions or negations of full text functions, false otherwise - */ - private static boolean onlyFullTextFunctionsInExpression(Expression expression) { - if (expression instanceof FullTextFunction) { - return true; - } else if (expression instanceof Not) { - return onlyFullTextFunctionsInExpression(expression.children().get(0)); - } else if (expression instanceof BinaryLogic binaryLogic) { - return onlyFullTextFunctionsInExpression(binaryLogic.left()) && onlyFullTextFunctionsInExpression(binaryLogic.right()); - } - - return false; - } - /** * Checks all commands that exist before a specific type satisfy conditions. * @@ -367,6 +298,6 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) { shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher()); } - return new LuceneQueryExpressionEvaluator.Factory(shardConfigs); + return new LuceneQueryExpressionEvaluator.Factory(shardConfigs, toEvaluator.usesScoring()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/logical/BinaryLogicOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/logical/BinaryLogicOperation.java index 5a5fad32327fb..64c0b08afaae0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/logical/BinaryLogicOperation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/logical/BinaryLogicOperation.java @@ -50,6 +50,10 @@ public Boolean apply(Boolean left, Boolean right) { return process.apply(left, right); } + public Double score(Double leftScore, Double rightScore) { + return leftScore + rightScore; + } + @Override public final Boolean doApply(Boolean left, Boolean right) { return null; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java index 70d87b7cc77ff..6526999976710 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java @@ -35,13 +35,14 @@ public final ExpressionEvaluator.Factory map( FoldContext foldCtx, InsensitiveEquals bc, Layout layout, - List shardContexts + List shardContexts, + boolean usesScoring ) { DataType leftType = bc.left().dataType(); DataType rightType = bc.right().dataType(); - var leftEval = toEvaluator(foldCtx, bc.left(), layout, shardContexts); - var rightEval = toEvaluator(foldCtx, bc.right(), layout, shardContexts); + var leftEval = toEvaluator(foldCtx, bc.left(), layout, shardContexts, usesScoring); + var rightEval = toEvaluator(foldCtx, bc.right(), layout, shardContexts, usesScoring); if (DataType.isString(leftType)) { if (bc.right().foldable() && DataType.isString(rightType)) { BytesRef rightVal = BytesRefs.toBytesRef(bc.right().fold(FoldContext.small() /* TODO remove me */)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index aa24ea113cb48..7b77b06aef2b1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -676,8 +676,12 @@ private PhysicalOperation planProject(ProjectExec project, LocalExecutionPlanner private PhysicalOperation planFilter(FilterExec filter, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(filter.child(), context); // TODO: should this be extracted into a separate eval block? + boolean usesScore = PlannerUtils.usesScoring(filter); return source.with( - new FilterOperatorFactory(EvalMapper.toEvaluator(context.foldCtx(), filter.condition(), source.layout, shardContexts)), + new FilterOperatorFactory( + EvalMapper.toEvaluator(context.foldCtx(), filter.condition(), source.layout, shardContexts, usesScore), + usesScore + ), source.layout ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java index c5139d45f4b37..83c808863f93d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer; +import org.elasticsearch.xpack.esql.plan.QueryPlan; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.join.Join; @@ -307,4 +309,8 @@ public static ElementType toElementType(DataType dataType, MappedFieldType.Field new NoopCircuitBreaker("noop-esql-breaker"), BigArrays.NON_RECYCLING_INSTANCE ); + + public static boolean usesScoring(QueryPlan plan) { + return plan.output().stream().anyMatch(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE)); + } }