
Commit 2fc193c

Merge remote-tracking branch 'tteofili/esql_metadata_score' into carlosdelest/esql-match-operator-colon-scoring-demo
# Conflicts:
#	x-pack/plugin/esql/qa/testFixtures/src/main/resources/match-operator.csv-spec
2 parents bb1fd05 + bb22685

19 files changed (+844, -30 lines)

docs/changelog/113120.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 113120
+summary: ESQL - enabling scoring with METADATA `_score`
+area: ES|QL
+type: enhancement
+issues: []

docs/reference/esql/metadata-fields.asciidoc

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ supported ones are:
 * <<mapping-ignored-field,`_ignored`>>: the ignored source document fields. The field is of the type
 <<keyword,keyword>>.
 
+* `_score`: the score of each document with respect to the portion of the ES|QL query that can be
+pushed down to Lucene. The field is of the type <<number,double>>.
+
 To enable the access to these fields, the <<esql-from,`FROM`>> source command needs
 to be provided with a dedicated directive:
 
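As a usage illustration (not part of this diff, and the exact query shape is an assumption): with the directive enabled, a query such as `FROM books METADATA _score | SORT _score DESC | LIMIT 10` would expose the pushed-down relevance score as an ordinary `double` column that can be sorted on and returned.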

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java

Lines changed: 4 additions & 1 deletion
@@ -31,6 +31,7 @@
 public class MetadataAttribute extends TypedAttribute {
     public static final String TIMESTAMP_FIELD = "@timestamp";
     public static final String TSID_FIELD = "_tsid";
+    public static final String SCORE = "_score";
 
     static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
         Attribute.class,
@@ -50,7 +51,9 @@ public class MetadataAttribute extends TypedAttribute {
         SourceFieldMapper.NAME,
         tuple(DataType.SOURCE, false),
         IndexModeFieldMapper.NAME,
-        tuple(DataType.KEYWORD, true)
+        tuple(DataType.KEYWORD, true),
+        SCORE,
+        tuple(DataType.DOUBLE, false)
     );
 
     private final boolean searchable;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

Lines changed: 2 additions & 1 deletion
@@ -438,7 +438,8 @@ static Function<ShardContext, Weight> weightFunction(Function<ShardContext, Quer
             final var query = queryFunction.apply(ctx);
             final var searcher = ctx.searcher();
             try {
-                return searcher.createWeight(searcher.rewrite(new ConstantScoreQuery(query)), scoreMode, 1);
+                Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
+                return searcher.createWeight(searcher.rewrite(actualQuery), scoreMode, 1);
             } catch (IOException e) {
                 throw new UncheckedIOException(e);
             }
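A standalone sketch of the pattern this change introduces (the helper class and method names below are illustrative, not from the codebase): the query is wrapped in ConstantScoreQuery only when the ScoreMode does not need scores, so a scoring mode now receives real relevance scores instead of the constant 1.0 produced by ConstantScoreQuery.

import java.io.IOException;

import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

final class WeightSketch {
    // Wrap in ConstantScoreQuery only when scores are not needed, then rewrite the
    // query and create the Weight, mirroring the updated weightFunction above.
    static Weight weightFor(IndexSearcher searcher, Query query, ScoreMode scoreMode) throws IOException {
        Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
        return searcher.createWeight(searcher.rewrite(actualQuery), scoreMode, 1f);
    }
}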

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java

Lines changed: 47 additions & 18 deletions
@@ -14,11 +14,13 @@
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.TopFieldCollector;
+import org.apache.lucene.search.TopDocsCollector;
 import org.apache.lucene.search.TopFieldCollectorManager;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DocVector;
+import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
@@ -38,10 +40,10 @@
 /**
  * Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
  */
-public final class LuceneTopNSourceOperator extends LuceneOperator {
-    public static final class Factory extends LuceneOperator.Factory {
-        private final int maxPageSize;
-        private final List<SortBuilder<?>> sorts;
+public class LuceneTopNSourceOperator extends LuceneOperator {
+    public static class Factory extends LuceneOperator.Factory {
+        protected final int maxPageSize;
+        protected final List<SortBuilder<?>> sorts;
 
         public Factory(
             List<? extends ShardContext> contexts,
@@ -50,9 +52,10 @@ public Factory(
             int taskConcurrency,
             int maxPageSize,
             int limit,
-            List<SortBuilder<?>> sorts
+            List<SortBuilder<?>> sorts,
+            ScoreMode scoreMode
         ) {
-            super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.TOP_DOCS);
+            super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoreMode);
             this.maxPageSize = maxPageSize;
             this.sorts = sorts;
         }
@@ -91,8 +94,8 @@ public String describe() {
     private int offset = 0;
 
     private PerShardCollector perShardCollector;
-    private final List<SortBuilder<?>> sorts;
-    private final int limit;
+    protected final List<SortBuilder<?>> sorts;
+    protected final int limit;
 
     public LuceneTopNSourceOperator(
         BlockFactory blockFactory,
@@ -145,7 +148,7 @@ private Page collect() throws IOException {
         try {
             if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
                 // TODO: share the bottom between shardCollectors
-                perShardCollector = new PerShardCollector(scorer.shardContext(), sorts, limit);
+                perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
             }
             var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
             scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
@@ -171,7 +174,7 @@ private Page emit(boolean startEmitting) {
         assert isEmitting() == false : "offset=" + offset + " score_docs=" + Arrays.toString(scoreDocs);
         offset = 0;
         if (perShardCollector != null) {
-            scoreDocs = perShardCollector.topFieldCollector.topDocs().scoreDocs;
+            scoreDocs = perShardCollector.collector.topDocs().scoreDocs;
         } else {
             scoreDocs = new ScoreDoc[0];
         }
@@ -183,10 +186,12 @@ private Page emit(boolean startEmitting) {
         IntBlock shard = null;
         IntVector segments = null;
         IntVector docs = null;
+        DocBlock docBlock = null;
         Page page = null;
         try (
             IntVector.Builder currentSegmentBuilder = blockFactory.newIntVectorFixedBuilder(size);
-            IntVector.Builder currentDocsBuilder = blockFactory.newIntVectorFixedBuilder(size)
+            IntVector.Builder currentDocsBuilder = blockFactory.newIntVectorFixedBuilder(size);
+            DoubleVector.Builder currentScoresBuilder = scoreVectorOrNull(size);
         ) {
             int start = offset;
             offset += size;
@@ -196,12 +201,14 @@ private Page emit(boolean startEmitting) {
                 int segment = ReaderUtil.subIndex(doc, leafContexts);
                 currentSegmentBuilder.appendInt(segment);
                 currentDocsBuilder.appendInt(doc - leafContexts.get(segment).docBase); // the offset inside the segment
+                consumeScore(scoreDocs[i], currentScoresBuilder);
             }
 
             shard = blockFactory.newConstantIntBlockWith(perShardCollector.shardContext.index(), size);
             segments = currentSegmentBuilder.build();
             docs = currentDocsBuilder.build();
-            page = new Page(size, new DocVector(shard.asVector(), segments, docs, null).asBlock());
+            docBlock = new DocVector(shard.asVector(), segments, docs, null).asBlock();
+            page = maybeAppendScore(new Page(size, docBlock), currentScoresBuilder);
         } finally {
             if (page == null) {
                 Releasables.closeExpectNoException(shard, segments, docs);
@@ -211,20 +218,41 @@
         return page;
     }
 
+    protected DoubleVector.Builder scoreVectorOrNull(int size) {
+        return null; // no scoring
+    }
+
+    protected void consumeScore(ScoreDoc scoreDoc, DoubleVector.Builder currentScoresBuilder) {
+        // no scoring
+        assert currentScoresBuilder == null;
+    }
+
+    protected Page maybeAppendScore(Page page, DoubleVector.Builder currentScoresBuilder) {
+        // no scoring
+        assert currentScoresBuilder == null;
+        return page;
+    }
+
     @Override
     protected void describe(StringBuilder sb) {
         sb.append(", limit = ").append(limit);
         String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
         sb.append(", sorts = [").append(notPrettySorts).append("]");
     }
 
-    static final class PerShardCollector {
-        private final ShardContext shardContext;
-        private final TopFieldCollector topFieldCollector;
+    PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
+        return new PerShardCollector(shardContext, sorts, limit);
+    }
+
+    static class PerShardCollector {
+        protected ShardContext shardContext;
+        protected TopDocsCollector<?> collector;
         private int leafIndex;
         private LeafCollector leafCollector;
         private Thread currentThread;
 
+        PerShardCollector() {}
+
         PerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
             this.shardContext = shardContext;
             Optional<SortAndFormats> sortAndFormats = shardContext.buildSort(sorts);
@@ -233,16 +261,17 @@ static final class PerShardCollector {
             }
 
             // We don't use CollectorManager here as we don't retrieve the total hits and sort by score.
-            this.topFieldCollector = new TopFieldCollectorManager(sortAndFormats.get().sort, limit, null, 0, false).newCollector();
+            this.collector = new TopFieldCollectorManager(sortAndFormats.get().sort, limit, null, 0, false).newCollector();
         }
 
         LeafCollector getLeafCollector(LeafReaderContext leafReaderContext) throws IOException {
             if (currentThread != Thread.currentThread() || leafIndex != leafReaderContext.ord) {
                 leafCollector = collector.getLeafCollector(leafReaderContext);
                 leafIndex = leafReaderContext.ord;
                 currentThread = Thread.currentThread();
            }
            return leafCollector;
        }
+
    }
 }
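To show how these new extension points compose, here is a minimal hypothetical sketch of a scoring variant (the class name, constructor wiring, and the exact BlockFactory builder method are assumptions; the actual scoring operator lives elsewhere in this change set): it allocates a score vector, copies each ScoreDoc's score, and appends the resulting block to the emitted Page.

import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;

// Hypothetical scoring hooks, mirroring the overridable methods added above.
class ScoringTopNSketch {
    private final BlockFactory blockFactory;

    ScoringTopNSketch(BlockFactory blockFactory) {
        this.blockFactory = blockFactory;
    }

    // Allocate a builder so scores are collected alongside doc ids (the base class returns null).
    protected DoubleVector.Builder scoreVectorOrNull(int size) {
        return blockFactory.newDoubleVectorFixedBuilder(size);
    }

    // Record the per-document score; ScoreDoc.score is a float widened to double here.
    protected void consumeScore(ScoreDoc scoreDoc, DoubleVector.Builder currentScoresBuilder) {
        currentScoresBuilder.appendDouble(scoreDoc.score);
    }

    // Attach the collected scores as an extra block (column) on the emitted page.
    protected Page maybeAppendScore(Page page, DoubleVector.Builder currentScoresBuilder) {
        return page.appendBlock(currentScoresBuilder.build().asBlock());
    }
}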
