Skip to content

Commit 146cb39

Browse files
ESQL - enabling scoring with METADATA _score (#113120)
* ESQL - enabling scoring with METADATA _score
Co-authored-by: ChrisHegarty <[email protected]>
1 parent 64dfed4 commit 146cb39

File tree

32 files changed

+1570
-96
lines changed

32 files changed

+1570
-96
lines changed

docs/changelog/113120.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 113120
2+
summary: ESQL - enabling scoring with METADATA `_score`
3+
area: ES|QL
4+
type: enhancement
5+
issues: []

muted-tests.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,12 @@ tests:
224224
issue: https://github.com/elastic/elasticsearch/issues/117591
225225
- class: org.elasticsearch.repositories.s3.RepositoryS3ClientYamlTestSuiteIT
226226
issue: https://github.com/elastic/elasticsearch/issues/117596
227+
- class: "org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT"
228+
method: "test {scoring.*}"
229+
issue: https://github.com/elastic/elasticsearch/issues/117641
230+
- class: "org.elasticsearch.xpack.esql.qa.single_node.EsqlSpecIT"
231+
method: "test {scoring.*}"
232+
issue: https://github.com/elastic/elasticsearch/issues/117641
227233

228234
# Examples:
229235
#

server/src/main/java/org/elasticsearch/search/sort/SortBuilder.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ private static void parseCompoundSortField(XContentParser parser, List<SortBuild
158158
}
159159

160160
public static Optional<SortAndFormats> buildSort(List<SortBuilder<?>> sortBuilders, SearchExecutionContext context) throws IOException {
161+
return buildSort(sortBuilders, context, true);
162+
}
163+
164+
public static Optional<SortAndFormats> buildSort(List<SortBuilder<?>> sortBuilders, SearchExecutionContext context, boolean optimize)
165+
throws IOException {
161166
List<SortField> sortFields = new ArrayList<>(sortBuilders.size());
162167
List<DocValueFormat> sortFormats = new ArrayList<>(sortBuilders.size());
163168
for (SortBuilder<?> builder : sortBuilders) {
@@ -172,9 +177,13 @@ public static Optional<SortAndFormats> buildSort(List<SortBuilder<?>> sortBuilde
172177
if (sortFields.size() > 1) {
173178
sort = true;
174179
} else {
175-
SortField sortField = sortFields.get(0);
176-
if (sortField.getType() == SortField.Type.SCORE && sortField.getReverse() == false) {
177-
sort = false;
180+
if (optimize) {
181+
SortField sortField = sortFields.get(0);
182+
if (sortField.getType() == SortField.Type.SCORE && sortField.getReverse() == false) {
183+
sort = false;
184+
} else {
185+
sort = true;
186+
}
178187
} else {
179188
sort = true;
180189
}

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
public class MetadataAttribute extends TypedAttribute {
3232
public static final String TIMESTAMP_FIELD = "@timestamp";
3333
public static final String TSID_FIELD = "_tsid";
34+
public static final String SCORE = "_score";
3435

3536
static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
3637
Attribute.class,
@@ -50,7 +51,9 @@ public class MetadataAttribute extends TypedAttribute {
5051
SourceFieldMapper.NAME,
5152
tuple(DataType.SOURCE, false),
5253
IndexModeFieldMapper.NAME,
53-
tuple(DataType.KEYWORD, true)
54+
tuple(DataType.KEYWORD, true),
55+
SCORE,
56+
tuple(DataType.DOUBLE, false)
5457
);
5558

5659
private final boolean searchable;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ public abstract static class Factory implements SourceOperator.SourceOperatorFac
7979
protected final DataPartitioning dataPartitioning;
8080
protected final int taskConcurrency;
8181
protected final int limit;
82+
protected final ScoreMode scoreMode;
8283
protected final LuceneSliceQueue sliceQueue;
8384

8485
/**
@@ -95,6 +96,7 @@ protected Factory(
9596
ScoreMode scoreMode
9697
) {
9798
this.limit = limit;
99+
this.scoreMode = scoreMode;
98100
this.dataPartitioning = dataPartitioning;
99101
var weightFunction = weightFunction(queryFunction, scoreMode);
100102
this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
@@ -438,7 +440,8 @@ static Function<ShardContext, Weight> weightFunction(Function<ShardContext, Quer
438440
final var query = queryFunction.apply(ctx);
439441
final var searcher = ctx.searcher();
440442
try {
441-
return searcher.createWeight(searcher.rewrite(new ConstantScoreQuery(query)), scoreMode, 1);
443+
Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
444+
return searcher.createWeight(searcher.rewrite(actualQuery), scoreMode, 1);
442445
} catch (IOException e) {
443446
throw new UncheckedIOException(e);
444447
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
import org.apache.lucene.search.Scorable;
1414
import org.apache.lucene.search.ScoreMode;
1515
import org.elasticsearch.compute.data.BlockFactory;
16+
import org.elasticsearch.compute.data.DocBlock;
1617
import org.elasticsearch.compute.data.DocVector;
18+
import org.elasticsearch.compute.data.DoubleVector;
1719
import org.elasticsearch.compute.data.IntBlock;
1820
import org.elasticsearch.compute.data.IntVector;
1921
import org.elasticsearch.compute.data.Page;
@@ -25,6 +27,9 @@
2527
import java.util.List;
2628
import java.util.function.Function;
2729

30+
import static org.apache.lucene.search.ScoreMode.COMPLETE;
31+
import static org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
32+
2833
/**
2934
* Source operator that incrementally runs Lucene searches
3035
*/
@@ -34,6 +39,7 @@ public class LuceneSourceOperator extends LuceneOperator {
3439
private int remainingDocs;
3540

3641
private IntVector.Builder docsBuilder;
42+
private DoubleVector.Builder scoreBuilder;
3743
private final LeafCollector leafCollector;
3844
private final int minPageSize;
3945

@@ -47,15 +53,16 @@ public Factory(
4753
DataPartitioning dataPartitioning,
4854
int taskConcurrency,
4955
int maxPageSize,
50-
int limit
56+
int limit,
57+
boolean scoring
5158
) {
52-
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
59+
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
5360
this.maxPageSize = maxPageSize;
5461
}
5562

5663
@Override
5764
public SourceOperator get(DriverContext driverContext) {
58-
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit);
65+
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, scoreMode);
5966
}
6067

6168
public int maxPageSize() {
@@ -70,32 +77,65 @@ public String describe() {
7077
+ maxPageSize
7178
+ ", limit = "
7279
+ limit
80+
+ ", scoreMode = "
81+
+ scoreMode
7382
+ "]";
7483
}
7584
}
7685

77-
public LuceneSourceOperator(BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue, int limit) {
86+
@SuppressWarnings("this-escape")
87+
public LuceneSourceOperator(BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue, int limit, ScoreMode scoreMode) {
7888
super(blockFactory, maxPageSize, sliceQueue);
7989
this.minPageSize = Math.max(1, maxPageSize / 2);
8090
this.remainingDocs = limit;
81-
this.docsBuilder = blockFactory.newIntVectorBuilder(Math.min(limit, maxPageSize));
82-
this.leafCollector = new LeafCollector() {
83-
@Override
84-
public void setScorer(Scorable scorer) {
85-
91+
int estimatedSize = Math.min(limit, maxPageSize);
92+
boolean success = false;
93+
try {
94+
this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
95+
if (scoreMode.needsScores()) {
96+
scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
97+
this.leafCollector = new ScoringCollector();
98+
} else {
99+
scoreBuilder = null;
100+
this.leafCollector = new LimitingCollector();
86101
}
102+
success = true;
103+
} finally {
104+
if (success == false) {
105+
close();
106+
}
107+
}
108+
}
87109

88-
@Override
89-
public void collect(int doc) {
90-
if (remainingDocs > 0) {
91-
--remainingDocs;
92-
docsBuilder.appendInt(doc);
93-
currentPagePos++;
94-
} else {
95-
throw new CollectionTerminatedException();
96-
}
110+
class LimitingCollector implements LeafCollector {
111+
@Override
112+
public void setScorer(Scorable scorer) {}
113+
114+
@Override
115+
public void collect(int doc) throws IOException {
116+
if (remainingDocs > 0) {
117+
--remainingDocs;
118+
docsBuilder.appendInt(doc);
119+
currentPagePos++;
120+
} else {
121+
throw new CollectionTerminatedException();
97122
}
98-
};
123+
}
124+
}
125+
126+
final class ScoringCollector extends LuceneSourceOperator.LimitingCollector {
127+
private Scorable scorable;
128+
129+
@Override
130+
public void setScorer(Scorable scorer) {
131+
this.scorable = scorer;
132+
}
133+
134+
@Override
135+
public void collect(int doc) throws IOException {
136+
super.collect(doc);
137+
scoreBuilder.appendDouble(scorable.score());
138+
}
99139
}
100140

101141
@Override
@@ -139,15 +179,27 @@ public Page getCheckedOutput() throws IOException {
139179
IntBlock shard = null;
140180
IntBlock leaf = null;
141181
IntVector docs = null;
182+
DoubleVector scores = null;
183+
DocBlock docBlock = null;
142184
try {
143185
shard = blockFactory.newConstantIntBlockWith(scorer.shardContext().index(), currentPagePos);
144186
leaf = blockFactory.newConstantIntBlockWith(scorer.leafReaderContext().ord, currentPagePos);
145187
docs = docsBuilder.build();
146188
docsBuilder = blockFactory.newIntVectorBuilder(Math.min(remainingDocs, maxPageSize));
147-
page = new Page(currentPagePos, new DocVector(shard.asVector(), leaf.asVector(), docs, true).asBlock());
189+
docBlock = new DocVector(shard.asVector(), leaf.asVector(), docs, true).asBlock();
190+
shard = null;
191+
leaf = null;
192+
docs = null;
193+
if (scoreBuilder == null) {
194+
page = new Page(currentPagePos, docBlock);
195+
} else {
196+
scores = scoreBuilder.build();
197+
scoreBuilder = blockFactory.newDoubleVectorBuilder(Math.min(remainingDocs, maxPageSize));
198+
page = new Page(currentPagePos, docBlock, scores.asBlock());
199+
}
148200
} finally {
149201
if (page == null) {
150-
Releasables.closeExpectNoException(shard, leaf, docs);
202+
Releasables.closeExpectNoException(shard, leaf, docs, docBlock, scores);
151203
}
152204
}
153205
currentPagePos = 0;
@@ -160,7 +212,7 @@ public Page getCheckedOutput() throws IOException {
160212

161213
@Override
162214
public void close() {
163-
docsBuilder.close();
215+
Releasables.close(docsBuilder, scoreBuilder);
164216
}
165217

166218
@Override

0 commit comments

Comments
 (0)