Skip to content

Commit adb7e3e

Browse files
committed
Implement usesScoring and solve scoring in binary logic
1 parent edb117a commit adb7e3e

File tree

12 files changed

+166
-45
lines changed

12 files changed

+166
-45
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.compute.data.BooleanVector;
2424
import org.elasticsearch.compute.data.DocBlock;
2525
import org.elasticsearch.compute.data.DocVector;
26+
import org.elasticsearch.compute.data.DoubleBlock;
2627
import org.elasticsearch.compute.data.DoubleVector;
2728
import org.elasticsearch.compute.data.IntVector;
2829
import org.elasticsearch.compute.data.Page;
@@ -83,7 +84,7 @@ public Block eval(Page page) {
8384
}
8485

8586
@Override
86-
public Block score(Page page, BlockFactory blockFactory) {
87+
public DoubleBlock score(Page page, BlockFactory blockFactory) {
8788
if (scoreVector == null) {
8889
Releasables.closeExpectNoException(eval(page));
8990
}
@@ -152,8 +153,10 @@ private Tuple<BooleanVector, DoubleVector> evalSlow(DocVector docs) throws IOExc
152153
int prevShard = -1;
153154
int prevSegment = -1;
154155
SegmentState segmentState = null;
155-
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
156-
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) {
156+
try (
157+
BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
158+
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())
159+
) {
157160
for (int i = 0; i < docs.getPositionCount(); i++) {
158161
int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
159162
int segment = docs.segments().getInt(map[i]);
@@ -170,7 +173,7 @@ private Tuple<BooleanVector, DoubleVector> evalSlow(DocVector docs) throws IOExc
170173
segmentState.scoreSingleDocWithScorer(builder, scoreBuilder, docs.docs().getInt(map[i]));
171174
}
172175
}
173-
try (BooleanVector outOfOrder = builder.build();DoubleVector outOfOrderScores = scoreBuilder.build()) {
176+
try (BooleanVector outOfOrder = builder.build(); DoubleVector outOfOrderScores = scoreBuilder.build()) {
174177
return Tuple.tuple(
175178
outOfOrder.filter(docs.shardSegmentDocMapBackwards()),
176179
outOfOrderScores.filter(docs.shardSegmentDocMapBackwards())
@@ -261,8 +264,10 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
261264
Tuple<BooleanVector, DoubleVector> scoreDense(int min, int max) throws IOException {
262265
int length = max - min + 1;
263266
if (noMatch) {
264-
return Tuple.tuple(blockFactory.newConstantBooleanVector(false, length),
265-
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length));
267+
return Tuple.tuple(
268+
blockFactory.newConstantBooleanVector(false, length),
269+
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length)
270+
);
266271
}
267272
if (bulkScorer == null || // The bulkScorer wasn't initialized
268273
Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
@@ -271,8 +276,10 @@ Tuple<BooleanVector, DoubleVector> scoreDense(int min, int max) throws IOExcepti
271276
bulkScorer = weight.bulkScorer(ctx);
272277
if (bulkScorer == null) {
273278
noMatch = true;
274-
return Tuple.tuple(blockFactory.newConstantBooleanVector(false, length),
275-
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length));
279+
return Tuple.tuple(
280+
blockFactory.newConstantBooleanVector(false, length),
281+
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length)
282+
);
276283
}
277284
}
278285

@@ -300,8 +307,10 @@ Tuple<BooleanVector, DoubleVector> scoreSparse(IntVector docs) throws IOExceptio
300307
blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, docs.getPositionCount())
301308
);
302309
}
303-
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
304-
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) {
310+
try (
311+
BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount());
312+
DoubleVector.Builder scoreBuilder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())
313+
) {
305314
for (int i = 0; i < docs.getPositionCount(); i++) {
306315
scoreSingleDocWithScorer(builder, scoreBuilder, docs.getInt(i));
307316
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import org.elasticsearch.compute.data.Block;
1111
import org.elasticsearch.compute.data.BlockFactory;
12-
import org.elasticsearch.compute.data.ConstantNullVector;
12+
import org.elasticsearch.compute.data.DoubleBlock;
1313
import org.elasticsearch.compute.data.Page;
1414
import org.elasticsearch.core.Releasable;
1515
import org.elasticsearch.core.Releasables;
@@ -87,7 +87,7 @@ default boolean eagerEvalSafeInLazy() {
8787
* Retrieves the score for the expression
8888
* Only expressions that can be scored, or that can combine scores, should override this method
8989
*/
90-
default Block score(Page page, BlockFactory blockFactory) {
90+
default DoubleBlock score(Page page, BlockFactory blockFactory) {
9191
return blockFactory.newConstantDoubleBlockWith(0.0, page.getPositionCount());
9292
}
9393
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
package org.elasticsearch.compute.operator;
99

1010
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.BlockFactory;
1112
import org.elasticsearch.compute.data.BooleanBlock;
13+
import org.elasticsearch.compute.data.DoubleBlock;
14+
import org.elasticsearch.compute.data.DoubleVector;
1215
import org.elasticsearch.compute.data.Page;
1316
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
1417
import org.elasticsearch.core.Releasables;
@@ -17,13 +20,17 @@
1720

1821
public class FilterOperator extends AbstractPageMappingOperator {
1922

23+
public static final int SCORE_BLOCK_INDEX = 1;
24+
2025
private final EvalOperator.ExpressionEvaluator evaluator;
26+
private final boolean usesScoring;
27+
private final BlockFactory blockFactory;
2128

22-
public record FilterOperatorFactory(ExpressionEvaluator.Factory evaluatorSupplier) implements OperatorFactory {
29+
public record FilterOperatorFactory(ExpressionEvaluator.Factory evaluatorSupplier, boolean usesScoring) implements OperatorFactory {
2330

2431
@Override
2532
public Operator get(DriverContext driverContext) {
26-
return new FilterOperator(evaluatorSupplier.get(driverContext));
33+
return new FilterOperator(evaluatorSupplier.get(driverContext), usesScoring, driverContext.blockFactory());
2734
}
2835

2936
@Override
@@ -32,30 +39,33 @@ public String describe() {
3239
}
3340
}
3441

35-
public FilterOperator(EvalOperator.ExpressionEvaluator evaluator) {
42+
public FilterOperator(ExpressionEvaluator evaluator, boolean usesScoring, BlockFactory blockFactory) {
3643
this.evaluator = evaluator;
44+
this.usesScoring = usesScoring;
45+
this.blockFactory = blockFactory;
3746
}
3847

3948
@Override
4049
protected Page process(Page page) {
4150
int rowCount = 0;
4251
int[] positions = new int[page.getPositionCount()];
4352

44-
try (BooleanBlock test = (BooleanBlock) evaluator.eval(page)) {
45-
if (test.areAllValuesNull()) {
53+
try (BooleanBlock filterResultBlock = (BooleanBlock) evaluator.eval(page)) {
54+
if (filterResultBlock.areAllValuesNull()) {
4655
// All results are null which is like false. No values selected.
4756
page.releaseBlocks();
4857
return null;
4958
}
59+
5060
// TODO we can detect constant true or false from the type
5161
// TODO or we could make a new method in bool-valued evaluators that returns a list of numbers
5262
for (int p = 0; p < page.getPositionCount(); p++) {
53-
if (test.isNull(p) || test.getValueCount(p) != 1) {
63+
if (filterResultBlock.isNull(p) || filterResultBlock.getValueCount(p) != 1) {
5464
// Null is like false
5565
// And, for now, multivalued results are like false too
5666
continue;
5767
}
58-
if (test.getBoolean(test.getFirstValueIndex(p))) {
68+
if (filterResultBlock.getBoolean(filterResultBlock.getFirstValueIndex(p))) {
5969
positions[rowCount++] = p;
6070
}
6171
}
@@ -64,7 +74,14 @@ protected Page process(Page page) {
6474
page.releaseBlocks();
6575
return null;
6676
}
67-
if (rowCount == page.getPositionCount()) {
77+
final DoubleBlock scoreBlock;
78+
if (usesScoring) {
79+
scoreBlock = evaluator.score(page, blockFactory);
80+
} else {
81+
scoreBlock = null;
82+
}
83+
84+
if (rowCount == page.getPositionCount() && (usesScoring == false || scoreBlock.asVector().isConstant())) {
6885
return page;
6986
}
7087
positions = Arrays.copyOf(positions, rowCount);
@@ -73,10 +90,15 @@ protected Page process(Page page) {
7390
boolean success = false;
7491
try {
7592
for (int i = 0; i < page.getBlockCount(); i++) {
76-
filteredBlocks[i] = page.getBlock(i).filter(positions);
93+
if (usesScoring && i == SCORE_BLOCK_INDEX) {
94+
filteredBlocks[i] = createScoresBlock(rowCount, page.getBlock(i), scoreBlock, positions);
95+
} else {
96+
filteredBlocks[i] = page.getBlock(i).filter(positions);
97+
}
7798
}
7899
success = true;
79100
} finally {
101+
Releasables.closeExpectNoException(scoreBlock);
80102
page.releaseBlocks();
81103
if (success == false) {
82104
Releasables.closeExpectNoException(filteredBlocks);
@@ -86,6 +108,15 @@ protected Page process(Page page) {
86108
}
87109
}
88110

111+
private Block createScoresBlock(int rowCount, DoubleBlock originalScoreBlock, DoubleBlock newScoreBlock, int[] positions) {
112+
// Create a new scores block with the retrieved scores, that will replace the existing one on the result page
113+
DoubleVector.Builder updatedScoresBuilder = blockFactory.newDoubleVectorBuilder(rowCount);
114+
for (int j = 0; j < rowCount; j++) {
115+
updatedScoresBuilder.appendDouble(originalScoreBlock.getDouble(positions[j]) + newScoreBlock.getDouble(positions[j]));
116+
}
117+
return updatedScoresBuilder.build().asBlock().filter(positions);
118+
}
119+
89120
@Override
90121
public String toString() {
91122
return "FilterOperator[" + "evaluator=" + evaluator + ']';

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs
184184
);
185185
LuceneQueryExpressionEvaluator luceneQueryEvaluator = new LuceneQueryExpressionEvaluator(
186186
blockFactory,
187-
new LuceneQueryExpressionEvaluator.ShardConfig[] { shard }, false);
187+
new LuceneQueryExpressionEvaluator.ShardConfig[] { shard },
188+
false
189+
);
188190

189191
List<Operator> operators = new ArrayList<>();
190192
if (shuffleDocs) {

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) {
6666
public String toString() {
6767
return "SameLastDigit[lhs=0, rhs=1]";
6868
}
69-
});
69+
}, false);
7070
}
7171

7272
@Override
@@ -113,7 +113,7 @@ public void testReadFromBlock() {
113113
new SequenceBooleanBlockSourceOperator(context.blockFactory(), List.of(true, false, true, false))
114114
);
115115
List<Page> results = drive(
116-
new FilterOperator.FilterOperatorFactory(dvrCtx -> new EvalOperatorTests.LoadFromPage(0)).get(context),
116+
new FilterOperator.FilterOperatorFactory(dvrCtx -> new EvalOperatorTests.LoadFromPage(0), false).get(context),
117117
input.iterator(),
118118
context
119119
);

0 commit comments

Comments
 (0)