Skip to content

Commit e182051

Browse files
committed
Continue refactoring LuceneQueryExpressionEvaluator to make the scoring behaviour and the type of vector returned pluggable
1 parent f76f0a7 commit e182051

File tree

7 files changed

+93
-405
lines changed

7 files changed

+93
-405
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.compute.data.BooleanVector;
2424
import org.elasticsearch.compute.data.DocBlock;
2525
import org.elasticsearch.compute.data.DocVector;
26+
import org.elasticsearch.compute.data.DoubleVector;
2627
import org.elasticsearch.compute.data.IntVector;
2728
import org.elasticsearch.compute.data.Page;
2829
import org.elasticsearch.compute.data.Vector;
@@ -50,6 +51,8 @@ public record ShardConfig(Query query, IndexSearcher searcher) {}
5051

5152
private ShardState[] perShardState = EMPTY_SHARD_STATES;
5253

54+
public static final double SCORE_FOR_FALSE = -1.0;
55+
5356
LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards, DocScorerVectorProvider docScorerVectorProvider) {
5457
this.blockFactory = blockFactory;
5558
this.shards = shards;
@@ -177,7 +180,7 @@ private class ShardState {
177180
private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES;
178181

179182
ShardState(ShardConfig config) throws IOException {
180-
weight = config.searcher.createWeight(config.query, ScoreMode.COMPLETE_NO_SCORES, 0.0f);
183+
weight = config.searcher.createWeight(config.query, docScorerVectorProvider.scoreMode(), 1.0f);
181184
searcher = config.searcher;
182185
}
183186

@@ -233,7 +236,7 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
233236
* Score a range using the {@link BulkScorer}. This should be faster
234237
* than using {@link #scoreSparse} for dense doc ids.
235238
*/
236-
BooleanVector scoreDense(int min, int max) throws IOException {
239+
Vector scoreDense(int min, int max) throws IOException {
237240
int length = max - min + 1;
238241
if (noMatch) {
239242
return blockFactory.newConstantBooleanVector(false, length);
@@ -261,7 +264,7 @@ BooleanVector scoreDense(int min, int max) throws IOException {
261264
Vector scoreSparse(IntVector docs) throws IOException {
262265
initScorer(docs.getInt(0));
263266
if (noMatch) {
264-
return docScorerVectorProvider.noneMatch(docs.getPositionCount());
267+
return docScorerVectorProvider.noMatchVector(docs.getPositionCount());
265268
}
266269
docScorerVectorProvider.init(docs.getPositionCount());
267270
for (int i = 0; i < docs.getPositionCount(); i++) {
@@ -310,17 +313,16 @@ private void scoreSingleDocWithScorer(int doc) throws IOException {
310313
* which isn't documented, but @jpountz swears is true.
311314
*/
312315
static class DenseCollector implements LeafCollector, Releasable {
313-
private final BooleanVector.FixedBuilder builder;
314316
private final int max;
315317
private final DocScorerVectorProvider docScorerVectorProvider;
316318
private Scorable scorable;
317319

318320
int next;
319321

320322
DenseCollector(BlockFactory blockFactory, DocScorerVectorProvider docScorerVectorProvider, int min, int max) {
321-
this.builder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1);
322-
this.max = max;
323323
this.docScorerVectorProvider = docScorerVectorProvider;
324+
this.docScorerVectorProvider.init(max - min + 1);
325+
this.max = max;
324326
next = min;
325327
}
326328

@@ -330,41 +332,43 @@ public void setScorer(Scorable scorable) {
330332
}
331333

332334
@Override
333-
public void collect(int doc) {
335+
public void collect(int doc) throws IOException {
334336
while (next++ < doc) {
335-
builder.appendBoolean(false);
337+
docScorerVectorProvider.scoreNoHit();
336338
}
337-
builder.appendBoolean(true);
339+
docScorerVectorProvider.scoreHit(scorable);
338340
}
339341

340-
public BooleanVector build() {
341-
return builder.build();
342+
public Vector build() {
343+
return docScorerVectorProvider.build();
342344
}
343345

344346
@Override
345347
public void finish() {
346348
while (next++ <= max) {
347-
builder.appendBoolean(false);
349+
docScorerVectorProvider.scoreNoHit();
348350
}
349351
}
350352

351353
@Override
352354
public void close() {
353-
Releasables.closeExpectNoException(builder);
355+
Releasables.closeExpectNoException(docScorerVectorProvider);
354356
}
355357
}
356358

357359
private interface DocScorerVectorProvider extends Releasable {
358360

359-
Vector noneMatch(int docs);
361+
Vector noMatchVector(int docs);
360362

361363
void init(int numDocs);
362364

363-
void scoreHit(Scorable scorable);
365+
void scoreHit(Scorable scorable) throws IOException;
364366

365367
void scoreNoHit();
366368

367369
Vector build();
370+
371+
ScoreMode scoreMode();
368372
}
369373

370374
static class NonScoringDocScorerVectorProvider implements DocScorerVectorProvider {
@@ -377,7 +381,12 @@ static class NonScoringDocScorerVectorProvider implements DocScorerVectorProvide
377381
}
378382

379383
@Override
380-
public Vector noneMatch(int docs) {
384+
public ScoreMode scoreMode() {
385+
return ScoreMode.COMPLETE_NO_SCORES;
386+
}
387+
388+
@Override
389+
public Vector noMatchVector(int docs) {
381390
return blockFactory.newConstantBooleanVector(false, docs);
382391
}
383392

@@ -400,7 +409,55 @@ public void scoreNoHit() {
400409

401410
@Override
402411
public Vector build() {
412+
assert builder != null : "init must be called before build";
413+
return builder.build();
414+
}
415+
416+
@Override
417+
public void close() {
418+
Releasables.closeExpectNoException(builder);
419+
}
420+
}
421+
422+
static class ScoringDocScorerVectorProvider implements DocScorerVectorProvider {
423+
424+
private final BlockFactory blockFactory;
425+
private DoubleVector.Builder builder;
426+
427+
ScoringDocScorerVectorProvider(BlockFactory blockFactory) {
428+
this.blockFactory = blockFactory;
429+
}
430+
431+
@Override
432+
public ScoreMode scoreMode() {
433+
return ScoreMode.COMPLETE;
434+
}
435+
436+
@Override
437+
public Vector noMatchVector(int docs) {
438+
return blockFactory.newConstantDoubleVector(SCORE_FOR_FALSE, docs);
439+
}
440+
441+
@Override
442+
public void init(int numDocs) {
443+
builder = blockFactory.newDoubleVectorFixedBuilder(numDocs);
444+
}
445+
446+
@Override
447+
public void scoreHit(Scorable scorable) throws IOException {
448+
assert builder != null : "init must be called before scoring";
449+
builder.appendDouble(scorable.score());
450+
}
451+
452+
@Override
453+
public void scoreNoHit() {
403454
assert builder != null : "init must be called before scoring";
455+
builder.appendDouble(SCORE_FOR_FALSE);
456+
}
457+
458+
@Override
459+
public Vector build() {
460+
assert builder != null : "init must be called before build";
404461
return builder.build();
405462
}
406463

@@ -412,18 +469,22 @@ public void close() {
412469

413470
public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
414471
private final ShardConfig[] shardConfigs;
472+
private final boolean useScoring;
415473

416-
public Factory(ShardConfig[] shardConfigs) {
474+
public Factory(ShardConfig[] shardConfigs, boolean useScoring) {
417475
this.shardConfigs = shardConfigs;
476+
this.useScoring = useScoring;
418477
}
419478

420479
@Override
421480
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
422-
return new LuceneQueryExpressionEvaluator(
423-
context.blockFactory(),
424-
shardConfigs,
425-
new NonScoringDocScorerVectorProvider(context.blockFactory())
426-
);
481+
final DocScorerVectorProvider docScorerVectorProvider;
482+
if (useScoring) {
483+
docScorerVectorProvider = new ScoringDocScorerVectorProvider(context.blockFactory());
484+
} else {
485+
docScorerVectorProvider = new NonScoringDocScorerVectorProvider(context.blockFactory());
486+
}
487+
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs, docScorerVectorProvider);
427488
}
428489
}
429490
}

0 commit comments

Comments
 (0)