Skip to content

Commit edb117a

Browse files
committed
Add score to EvalOperator and LuceneQueryExpressionEvaluator
1 parent fc500d1 commit edb117a

File tree

3 files changed

+143
-42
lines changed

3 files changed

+143
-42
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java

Lines changed: 126 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
import org.elasticsearch.compute.data.BooleanVector;
2424
import org.elasticsearch.compute.data.DocBlock;
2525
import org.elasticsearch.compute.data.DocVector;
26+
import org.elasticsearch.compute.data.DoubleVector;
2627
import org.elasticsearch.compute.data.IntVector;
2728
import org.elasticsearch.compute.data.Page;
2829
import org.elasticsearch.compute.operator.DriverContext;
2930
import org.elasticsearch.compute.operator.EvalOperator;
3031
import org.elasticsearch.core.Releasable;
3132
import org.elasticsearch.core.Releasables;
33+
import org.elasticsearch.core.Tuple;
3234

3335
import java.io.IOException;
3436
import java.io.UncheckedIOException;
@@ -41,16 +43,22 @@
4143
* this evaluator is here to save the day.
4244
*/
4345
public class LuceneQueryExpressionEvaluator implements EvalOperator.ExpressionEvaluator {
46+
47+
public static final float NO_MATCH_SCORE = -1.0F;
48+
4449
/**
 * The query to run against a shard together with the {@link IndexSearcher}
 * that is used to create the query's {@link Weight} for that shard.
 */
public record ShardConfig(Query query, IndexSearcher searcher) {}
4550

4651
private final BlockFactory blockFactory;
4752
private final ShardConfig[] shards;
53+
private final boolean usesScoring;
4854

4955
private ShardState[] perShardState = EMPTY_SHARD_STATES;
56+
private DoubleVector scoreVector;
5057

51-
/**
 * Creates an evaluator that runs the configured Lucene queries against incoming doc blocks.
 *
 * @param blockFactory factory used to allocate the result vectors
 * @param shards       query and searcher configuration for the shards
 * @param usesScoring  {@code true} to create the query weights with scoring enabled
 */
public LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards, boolean usesScoring) {
    this.usesScoring = usesScoring;
    this.shards = shards;
    this.blockFactory = blockFactory;
}
5563

5664
@Override
@@ -60,16 +68,28 @@ public Block eval(Page page) {
6068
assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
6169
DocVector docs = (DocVector) block.asVector();
6270
try {
71+
final Tuple<BooleanVector, DoubleVector> evalResult;
6372
if (docs.singleSegmentNonDecreasing()) {
64-
return evalSingleSegmentNonDecreasing(docs).asBlock();
73+
evalResult = evalSingleSegmentNonDecreasing(docs);
6574
} else {
66-
return evalSlow(docs).asBlock();
75+
evalResult = evalSlow(docs);
6776
}
77+
// Cache the score vector for later use
78+
scoreVector = evalResult.v2();
79+
return evalResult.v1().asBlock();
6880
} catch (IOException e) {
6981
throw new UncheckedIOException(e);
7082
}
7183
}
7284

85+
@Override
86+
public Block score(Page page, BlockFactory blockFactory) {
87+
if (scoreVector == null) {
88+
Releasables.closeExpectNoException(eval(page));
89+
}
90+
return scoreVector.asBlock();
91+
}
92+
7393
/**
7494
* Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents.
7595
* <p>
@@ -100,7 +120,7 @@ public Block eval(Page page) {
100120
* common.
101121
* </p>
102122
*/
103-
private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
123+
private Tuple<BooleanVector, DoubleVector> evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
104124
ShardState shardState = shardState(docs.shards().getInt(0));
105125
SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
106126
int min = docs.docs().getInt(0);
@@ -126,13 +146,14 @@ private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOEx
126146
* the order that the {@link DocVector} came in.
127147
* </p>
128148
*/
129-
private BooleanVector evalSlow(DocVector docs) throws IOException {
149+
private Tuple<BooleanVector, DoubleVector> evalSlow(DocVector docs) throws IOException {
130150
int[] map = docs.shardSegmentDocMapForwards();
131151
// Clear any state flags from the previous run
132152
int prevShard = -1;
133153
int prevSegment = -1;
134154
SegmentState segmentState = null;
135-
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
155+
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
156+
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) {
136157
for (int i = 0; i < docs.getPositionCount(); i++) {
137158
int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
138159
int segment = docs.segments().getInt(map[i]);
@@ -144,19 +165,23 @@ private BooleanVector evalSlow(DocVector docs) throws IOException {
144165
}
145166
if (segmentState.noMatch) {
146167
builder.appendBoolean(false);
168+
scoreBuilder.appendDouble(NO_MATCH_SCORE);
147169
} else {
148-
segmentState.scoreSingleDocWithScorer(builder, docs.docs().getInt(map[i]));
170+
segmentState.scoreSingleDocWithScorer(builder, scoreBuilder, docs.docs().getInt(map[i]));
149171
}
150172
}
151-
try (BooleanVector outOfOrder = builder.build()) {
152-
return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
173+
try (BooleanVector outOfOrder = builder.build();DoubleVector outOfOrderScores = scoreBuilder.build()) {
174+
return Tuple.tuple(
175+
outOfOrder.filter(docs.shardSegmentDocMapBackwards()),
176+
outOfOrderScores.filter(docs.shardSegmentDocMapBackwards())
177+
);
153178
}
154179
}
155180
}
156181

157182
@Override
158183
public void close() {
159-
184+
Releasables.closeExpectNoException(scoreVector);
160185
}
161186

162187
private ShardState shardState(int shard) throws IOException {
@@ -175,7 +200,9 @@ private class ShardState {
175200
private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES;
176201

177202
ShardState(ShardConfig config) throws IOException {
178-
weight = config.searcher.createWeight(config.query, ScoreMode.COMPLETE_NO_SCORES, 0.0f);
203+
float boost = usesScoring ? 1.0f : 0.0f;
204+
ScoreMode scoreMode = usesScoring ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
205+
weight = config.searcher.createWeight(config.query, scoreMode, boost);
179206
searcher = config.searcher;
180207
}
181208

@@ -231,10 +258,11 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
231258
* Score a range using the {@link BulkScorer}. This should be faster
232259
* than using {@link #scoreSparse} for dense doc ids.
233260
*/
234-
BooleanVector scoreDense(int min, int max) throws IOException {
261+
Tuple<BooleanVector, DoubleVector> scoreDense(int min, int max) throws IOException {
235262
int length = max - min + 1;
236263
if (noMatch) {
237-
return blockFactory.newConstantBooleanVector(false, length);
264+
return Tuple.tuple(blockFactory.newConstantBooleanVector(false, length),
265+
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length));
238266
}
239267
if (bulkScorer == null || // The bulkScorer wasn't initialized
240268
Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
@@ -243,29 +271,41 @@ BooleanVector scoreDense(int min, int max) throws IOException {
243271
bulkScorer = weight.bulkScorer(ctx);
244272
if (bulkScorer == null) {
245273
noMatch = true;
246-
return blockFactory.newConstantBooleanVector(false, length);
274+
return Tuple.tuple(blockFactory.newConstantBooleanVector(false, length),
275+
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length));
247276
}
248277
}
249-
try (DenseCollector collector = new DenseCollector(blockFactory, min, max)) {
278+
279+
final DenseCollector collector;
280+
if (usesScoring) {
281+
collector = new DenseCollector(blockFactory, min, max);
282+
} else {
283+
collector = new ScoringDenseCollector(blockFactory, min, max);
284+
}
285+
try (collector) {
250286
bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
251-
return collector.build();
287+
return Tuple.tuple(collector.buildMatchVector(), collector.buildScoreVector());
252288
}
253289
}
254290

255291
/**
256292
* Score a vector of doc ids using {@link Scorer}. If you have a dense range of
257293
* doc ids it'd be faster to use {@link #scoreDense}.
258294
*/
259-
BooleanVector scoreSparse(IntVector docs) throws IOException {
295+
Tuple<BooleanVector, DoubleVector> scoreSparse(IntVector docs) throws IOException {
260296
initScorer(docs.getInt(0));
261297
if (noMatch) {
262-
return blockFactory.newConstantBooleanVector(false, docs.getPositionCount());
298+
return Tuple.tuple(
299+
blockFactory.newConstantBooleanVector(false, docs.getPositionCount()),
300+
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, docs.getPositionCount())
301+
);
263302
}
264-
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
303+
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
304+
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) {
265305
for (int i = 0; i < docs.getPositionCount(); i++) {
266-
scoreSingleDocWithScorer(builder, docs.getInt(i));
306+
scoreSingleDocWithScorer(builder, scoreBuilder, docs.getInt(i));
267307
}
268-
return builder.build();
308+
return Tuple.tuple(builder.build(), scoreBuilder.build());
269309
}
270310
}
271311

@@ -285,7 +325,8 @@ private void initScorer(int minDocId) throws IOException {
285325
}
286326
}
287327

288-
private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) throws IOException {
328+
private void scoreSingleDocWithScorer(BooleanVector.Builder builder, DoubleVector.Builder scoreBuilder, int doc)
329+
throws IOException {
289330
if (scorer.iterator().docID() == doc) {
290331
builder.appendBoolean(true);
291332
} else if (scorer.iterator().docID() > doc) {
@@ -305,13 +346,13 @@ private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) th
305346
* which isn't documented, but @jpountz swears is true.
306347
*/
307348
static class DenseCollector implements LeafCollector, Releasable {
308-
private final BooleanVector.FixedBuilder builder;
349+
private final BooleanVector.FixedBuilder matchBuilder;
309350
private final int max;
310351

311352
int next;
312353

313354
DenseCollector(BlockFactory blockFactory, int min, int max) {
314-
this.builder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1);
355+
this.matchBuilder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1);
315356
this.max = max;
316357
next = min;
317358
}
@@ -320,40 +361,92 @@ static class DenseCollector implements LeafCollector, Releasable {
320361
public void setScorer(Scorable scorable) {}
321362

322363
@Override
323-
public void collect(int doc) {
364+
public void collect(int doc) throws IOException {
324365
while (next++ < doc) {
325-
builder.appendBoolean(false);
366+
appendNoMatch();
326367
}
327-
builder.appendBoolean(true);
368+
appendMatch();
369+
}
370+
371+
protected void appendMatch() throws IOException {
372+
matchBuilder.appendBoolean(true);
373+
}
374+
375+
protected void appendNoMatch() {
376+
matchBuilder.appendBoolean(false);
328377
}
329378

330-
public BooleanVector build() {
331-
return builder.build();
379+
public BooleanVector buildMatchVector() {
380+
return matchBuilder.build();
381+
}
382+
383+
public DoubleVector buildScoreVector() {
384+
return null;
332385
}
333386

334387
@Override
335388
public void finish() {
336389
while (next++ <= max) {
337-
builder.appendBoolean(false);
390+
appendNoMatch();
338391
}
339392
}
340393

341394
@Override
342395
public void close() {
343-
Releasables.closeExpectNoException(builder);
396+
Releasables.closeExpectNoException(matchBuilder);
397+
}
398+
}
399+
400+
static class ScoringDenseCollector extends DenseCollector {
401+
private final DoubleVector.FixedBuilder scoreBuilder;
402+
private Scorable scorable;
403+
404+
ScoringDenseCollector(BlockFactory blockFactory, int min, int max) {
405+
super(blockFactory, min, max);
406+
this.scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(max - min + 1);
407+
}
408+
409+
@Override
410+
public void setScorer(Scorable scorable) {
411+
this.scorable = scorable;
412+
}
413+
414+
@Override
415+
protected void appendMatch() throws IOException {
416+
super.appendMatch();
417+
scoreBuilder.appendDouble(scorable.score());
418+
}
419+
420+
@Override
421+
protected void appendNoMatch() {
422+
super.appendNoMatch();
423+
scoreBuilder.appendDouble(NO_MATCH_SCORE);
424+
}
425+
426+
@Override
427+
public DoubleVector buildScoreVector() {
428+
return scoreBuilder.build();
429+
}
430+
431+
@Override
432+
public void close() {
433+
super.close();
434+
Releasables.closeExpectNoException(scoreBuilder);
344435
}
345436
}
346437

347438
public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
348439
private final ShardConfig[] shardConfigs;
440+
private final boolean usesScoring;
349441

350-
public Factory(ShardConfig[] shardConfigs) {
442+
public Factory(ShardConfig[] shardConfigs, boolean usesScoring) {
351443
this.shardConfigs = shardConfigs;
444+
this.usesScoring = usesScoring;
352445
}
353446

354447
@Override
355448
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
356-
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
449+
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs, usesScoring);
357450
}
358451
}
359452
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.compute.data.Block;
1111
import org.elasticsearch.compute.data.BlockFactory;
12+
import org.elasticsearch.compute.data.ConstantNullVector;
1213
import org.elasticsearch.compute.data.Page;
1314
import org.elasticsearch.core.Releasable;
1415
import org.elasticsearch.core.Releasables;
@@ -81,6 +82,14 @@ default boolean eagerEvalSafeInLazy() {
8182
* @return the returned Block has its own reference and the caller is responsible for releasing it.
8283
*/
8384
Block eval(Page page);
85+
86+
/**
87+
* Retrieves the score for the expression
88+
* Only expressions that can be scored, or that can combine scores, should override this method
89+
*/
90+
default Block score(Page page, BlockFactory blockFactory) {
91+
return blockFactory.newConstantDoubleBlockWith(0.0, page.getPositionCount());
92+
}
8493
}
8594

8695
public static final ExpressionEvaluator.Factory CONSTANT_NULL_FACTORY = new ExpressionEvaluator.Factory() {

0 commit comments

Comments
 (0)