Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/121322.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 121322
summary: "[PoC 2] ESQL - Add scoring for full text functions disjunctions using `ExpressionEvaluator`"
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand All @@ -41,16 +44,20 @@
* this evaluator is here to save the day.
*/
public class LuceneQueryExpressionEvaluator implements EvalOperator.ExpressionEvaluator {

/** Pairs a Lucene {@link Query} with the {@link IndexSearcher} used to create its weight; the evaluator keeps one per shard. */
public record ShardConfig(Query query, IndexSearcher searcher) {}

private final BlockFactory blockFactory;
private final ShardConfig[] shards;
private final boolean usesScoring;

private ShardState[] perShardState = EMPTY_SHARD_STATES;
private DoubleVector scoreVector;

public LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
public LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards, boolean usesScoring) {
this.blockFactory = blockFactory;
this.shards = shards;
this.usesScoring = usesScoring;
}

@Override
Expand All @@ -60,16 +67,26 @@ public Block eval(Page page) {
assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
DocVector docs = (DocVector) block.asVector();
try {
final Tuple<BooleanVector, DoubleVector> evalResult;
if (docs.singleSegmentNonDecreasing()) {
return evalSingleSegmentNonDecreasing(docs).asBlock();
evalResult = evalSingleSegmentNonDecreasing(docs);
} else {
return evalSlow(docs).asBlock();
evalResult = evalSlow(docs);
}
// Cache the score vector for later use
scoreVector = evalResult.v2();
return evalResult.v1().asBlock();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

/**
 * Returns the scores cached by the most recent call to {@code eval}.
 * Neither argument is consulted; the cached vector is simply exposed as a block.
 */
@Override
public DoubleBlock score(Page page, BlockFactory blockFactory) {
    final DoubleVector cachedScores = scoreVector;
    assert cachedScores != null : "eval() should be invoked before calling score()";
    return cachedScores.asBlock();
}

/**
* Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents.
* <p>
Expand Down Expand Up @@ -100,7 +117,7 @@ public Block eval(Page page) {
* common.
* </p>
*/
private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
private Tuple<BooleanVector, DoubleVector> evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
ShardState shardState = shardState(docs.shards().getInt(0));
SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
int min = docs.docs().getInt(0);
Expand All @@ -126,13 +143,16 @@ private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOEx
* the order that the {@link DocVector} came in.
* </p>
*/
private BooleanVector evalSlow(DocVector docs) throws IOException {
private Tuple<BooleanVector, DoubleVector> evalSlow(DocVector docs) throws IOException {
int[] map = docs.shardSegmentDocMapForwards();
// Clear any state flags from the previous run
int prevShard = -1;
int prevSegment = -1;
SegmentState segmentState = null;
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
try (
BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())
) {
for (int i = 0; i < docs.getPositionCount(); i++) {
int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
int segment = docs.segments().getInt(map[i]);
Expand All @@ -144,19 +164,26 @@ private BooleanVector evalSlow(DocVector docs) throws IOException {
}
if (segmentState.noMatch) {
builder.appendBoolean(false);
scoreBuilder.appendDouble(NO_MATCH_SCORE);
} else {
segmentState.scoreSingleDocWithScorer(builder, docs.docs().getInt(map[i]));
segmentState.scoreSingleDocWithScorer(builder, scoreBuilder, docs.docs().getInt(map[i]));
}
}
try (BooleanVector outOfOrder = builder.build()) {
return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
try (BooleanVector outOfOrder = builder.build(); DoubleVector outOfOrderScores = scoreBuilder.build()) {
return Tuple.tuple(
outOfOrder.filter(docs.shardSegmentDocMapBackwards()),
outOfOrderScores.filter(docs.shardSegmentDocMapBackwards())
);
}
}
}

/**
 * Releases the cached score vector if one exists and nobody has released it yet.
 */
@Override
public void close() {
    final DoubleVector pendingScores = scoreVector;
    if (pendingScores == null || pendingScores.isReleased()) {
        // Nothing cached, or it was already released elsewhere.
        return;
    }
    // score() is not always called (for example with NOT), so the vector may still be live here.
    Releasables.closeExpectNoException(pendingScores);
}

private ShardState shardState(int shard) throws IOException {
Expand All @@ -175,7 +202,9 @@ private class ShardState {
private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES;

ShardState(ShardConfig config) throws IOException {
weight = config.searcher.createWeight(config.query, ScoreMode.COMPLETE_NO_SCORES, 0.0f);
float boost = usesScoring ? 1.0f : 0.0f;
ScoreMode scoreMode = usesScoring ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
weight = config.searcher.createWeight(config.query, scoreMode, boost);
searcher = config.searcher;
}

Expand Down Expand Up @@ -231,10 +260,13 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
* Score a range using the {@link BulkScorer}. This should be faster
* than using {@link #scoreSparse} for dense doc ids.
*/
BooleanVector scoreDense(int min, int max) throws IOException {
Tuple<BooleanVector, DoubleVector> scoreDense(int min, int max) throws IOException {
int length = max - min + 1;
if (noMatch) {
return blockFactory.newConstantBooleanVector(false, length);
return Tuple.tuple(
blockFactory.newConstantBooleanVector(false, length),
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length)
);
}
if (bulkScorer == null || // The bulkScorer wasn't initialized
Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
Expand All @@ -243,29 +275,45 @@ BooleanVector scoreDense(int min, int max) throws IOException {
bulkScorer = weight.bulkScorer(ctx);
if (bulkScorer == null) {
noMatch = true;
return blockFactory.newConstantBooleanVector(false, length);
return Tuple.tuple(
blockFactory.newConstantBooleanVector(false, length),
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length)
);
}
}
try (DenseCollector collector = new DenseCollector(blockFactory, min, max)) {

final DenseCollector collector;
if (usesScoring) {
collector = new ScoringDenseCollector(blockFactory, min, max);
} else {
collector = new DenseCollector(blockFactory, min, max);
}
try (collector) {
bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
return collector.build();
return Tuple.tuple(collector.buildMatchVector(), collector.buildScoreVector());
}
}

/**
* Score a vector of doc ids using {@link Scorer}. If you have a dense range of
* doc ids it'd be faster to use {@link #scoreDense}.
*/
BooleanVector scoreSparse(IntVector docs) throws IOException {
Tuple<BooleanVector, DoubleVector> scoreSparse(IntVector docs) throws IOException {
initScorer(docs.getInt(0));
if (noMatch) {
return blockFactory.newConstantBooleanVector(false, docs.getPositionCount());
return Tuple.tuple(
blockFactory.newConstantBooleanVector(false, docs.getPositionCount()),
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, docs.getPositionCount())
);
}
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
try (
BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())
) {
for (int i = 0; i < docs.getPositionCount(); i++) {
scoreSingleDocWithScorer(builder, docs.getInt(i));
scoreSingleDocWithScorer(builder, scoreBuilder, docs.getInt(i));
}
return builder.build();
return Tuple.tuple(builder.build(), scoreBuilder.build());
}
}

Expand All @@ -285,7 +333,8 @@ private void initScorer(int minDocId) throws IOException {
}
}

private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) throws IOException {
private void scoreSingleDocWithScorer(BooleanVector.Builder builder, DoubleVector.Builder scoreBuilder, int doc)
throws IOException {
if (scorer.iterator().docID() == doc) {
builder.appendBoolean(true);
} else if (scorer.iterator().docID() > doc) {
Expand All @@ -305,13 +354,13 @@ private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) th
* which isn't documented, but @jpountz swears is true.
*/
static class DenseCollector implements LeafCollector, Releasable {
private final BooleanVector.FixedBuilder builder;
private final BooleanVector.FixedBuilder matchBuilder;
private final int max;

int next;

DenseCollector(BlockFactory blockFactory, int min, int max) {
this.builder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1);
this.matchBuilder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1);
this.max = max;
next = min;
}
Expand All @@ -320,40 +369,92 @@ static class DenseCollector implements LeafCollector, Releasable {
public void setScorer(Scorable scorable) {}

@Override
public void collect(int doc) {
public void collect(int doc) throws IOException {
while (next++ < doc) {
builder.appendBoolean(false);
appendNoMatch();
}
builder.appendBoolean(true);
appendMatch();
}

protected void appendMatch() throws IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why appendMatch() throws IOException, but appendNoMatch() doesn't?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks weird, right?

appendMatch() reads the score from the Scorable, which can result in an IOException. appendNoMatch() does not need to read the score, so it has nothing that can throw.

matchBuilder.appendBoolean(true);
}

protected void appendNoMatch() {
matchBuilder.appendBoolean(false);
}

public BooleanVector build() {
return builder.build();
public BooleanVector buildMatchVector() {
return matchBuilder.build();
}

public DoubleVector buildScoreVector() {
return null;
}

@Override
public void finish() {
while (next++ <= max) {
builder.appendBoolean(false);
appendNoMatch();
}
}

@Override
public void close() {
Releasables.closeExpectNoException(builder);
Releasables.closeExpectNoException(matchBuilder);
}
}

/**
 * A {@link DenseCollector} that records a score next to every match flag.
 * <p>
 * Matching documents get the value reported by the current {@link Scorable};
 * non-matching positions get {@code NO_MATCH_SCORE}. The score builder is sized
 * like the match builder ({@code max - min + 1} positions).
 * </p>
 */
static class ScoringDenseCollector extends DenseCollector {
    private final DoubleVector.FixedBuilder scoreBuilder;
    private Scorable scorable;

    /**
     * @param blockFactory factory used to allocate both the match and the score builders
     * @param min first doc id of the range being collected
     * @param max last doc id (inclusive) of the range being collected
     */
    ScoringDenseCollector(BlockFactory blockFactory, int min, int max) {
        super(blockFactory, min, max);
        // The superclass has already allocated its match builder. If allocating the score
        // builder fails, nobody else holds a reference to this half-built collector (callers
        // only wrap it in try-with-resources after construction), so release the match
        // builder here instead of leaking it.
        DoubleVector.FixedBuilder scores = null;
        try {
            scores = blockFactory.newDoubleVectorFixedBuilder(max - min + 1);
        } finally {
            if (scores == null) {
                // Deliberately super.close(): the overriding close() would touch scoreBuilder.
                super.close();
            }
        }
        this.scoreBuilder = scores;
    }

    /** Remembers the scorable so {@link #appendMatch()} can read the score of the current doc. */
    @Override
    public void setScorer(Scorable scorable) {
        this.scorable = scorable;
    }

    /**
     * Appends a match flag and the score of the current document.
     *
     * @throws IOException if reading the score from the scorable fails
     */
    @Override
    protected void appendMatch() throws IOException {
        super.appendMatch();
        scoreBuilder.appendDouble(scorable.score());
    }

    /** Appends a non-match flag together with {@code NO_MATCH_SCORE}. */
    @Override
    protected void appendNoMatch() {
        super.appendNoMatch();
        scoreBuilder.appendDouble(NO_MATCH_SCORE);
    }

    /** Builds the vector of collected scores. */
    @Override
    public DoubleVector buildScoreVector() {
        return scoreBuilder.build();
    }

    /** Releases the match builder first, then the score builder. */
    @Override
    public void close() {
        super.close();
        Releasables.closeExpectNoException(scoreBuilder);
    }
}

public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
private final ShardConfig[] shardConfigs;
private final boolean usesScoring;

public Factory(ShardConfig[] shardConfigs) {
public Factory(ShardConfig[] shardConfigs, boolean usesScoring) {
this.shardConfigs = shardConfigs;
this.usesScoring = usesScoring;
}

@Override
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs, usesScoring);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
Expand Down Expand Up @@ -60,6 +61,12 @@ public void close() {
* Evaluates an expression {@code a + b} or {@code log(c)} one {@link Page} at a time.
*/
public interface ExpressionEvaluator extends Releasable {

/**
* Scoring value when an expression does not match (returns false) or does not contribute to the score
*/
double NO_MATCH_SCORE = 0.0;

/** A Factory for creating ExpressionEvaluators. */
interface Factory {
ExpressionEvaluator get(DriverContext context);
Expand All @@ -81,6 +88,14 @@ default boolean eagerEvalSafeInLazy() {
* @return the returned Block has its own reference and the caller is responsible for releasing it.
*/
Block eval(Page page);

/**
 * Produces the scores of this expression for the given page.
 * <p>
 * This default returns a constant block holding {@link #NO_MATCH_SCORE} for every position of the page.
 * Only expressions that can be scored, or that can combine scores, should override this method.
 * </p>
 */
default DoubleBlock score(Page page, BlockFactory blockFactory) {
    final int positions = page.getPositionCount();
    return blockFactory.newConstantDoubleBlockWith(NO_MATCH_SCORE, positions);
}
}

public static final ExpressionEvaluator.Factory CONSTANT_NULL_FACTORY = new ExpressionEvaluator.Factory() {
Expand Down
Loading