Skip to content

Commit 0a3043b

Browse files
committed
Improve handling of scores in boolean operations
1 parent 4f291ad commit 0a3043b

File tree

5 files changed

+80
-30
lines changed

5 files changed

+80
-30
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@
4545
*/
4646
public class LuceneQueryExpressionEvaluator implements EvalOperator.ExpressionEvaluator {
4747

48-
public static final float NO_MATCH_SCORE = -1.0F;
49-
5048
public record ShardConfig(Query query, IndexSearcher searcher) {}
5149

5250
private final BlockFactory blockFactory;
@@ -181,8 +179,7 @@ private Tuple<BooleanVector, DoubleVector> evalSlow(DocVector docs) throws IOExc
181179
}
182180

183181
@Override
184-
public void close() {
185-
}
182+
public void close() {}
186183

187184
private ShardState shardState(int shard) throws IOException {
188185
if (shard >= perShardState.length) {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ public void close() {
6161
* Evaluates an expression {@code a + b} or {@code log(c)} one {@link Page} at a time.
6262
*/
6363
public interface ExpressionEvaluator extends Releasable {
64+
65+
/**
66+
* Scoring value when an expression does not match (return false)
67+
*/
68+
double NO_MATCH_SCORE = -1.0;
69+
70+
/**
71+
* Default score for expressions that match (return true) but are not full text functions
72+
*/
73+
double MATCH_SCORE = 0.0;
74+
6475
/** A Factory for creating ExpressionEvaluators. */
6576
interface Factory {
6677
ExpressionEvaluator get(DriverContext context);
@@ -84,11 +95,11 @@ default boolean eagerEvalSafeInLazy() {
8495
Block eval(Page page);
8596

8697
/**
87-
* Retrieves the score for the expression
98+
* Retrieves the score for the expression.
8899
* Only expressions that can be scored, or that can combine scores, should override this method
89100
*/
90101
default DoubleBlock score(Page page, BlockFactory blockFactory) {
91-
return blockFactory.newConstantDoubleBlockWith(0.0, page.getPositionCount());
102+
return blockFactory.newConstantDoubleBlockWith(MATCH_SCORE, page.getPositionCount());
92103
}
93104
}
94105

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/MatchFunctionIT.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,10 @@ public void testDisjunctionScoring() {
283283
assertThat(values.get(2).get(0), equalTo(2));
284284

285285
// Matches full text query and non pushable query
286-
assertThat((Double) values.get(0).get(1), greaterThan(1.0));
287-
assertThat((Double) values.get(1).get(1), greaterThan(1.0));
286+
assertThat((Double) values.get(0).get(1), greaterThan(0.0));
287+
assertThat((Double) values.get(1).get(1), greaterThan(0.0));
288288
// Matches just non pushable query
289-
assertThat((Double) values.get(2).get(1), equalTo(1.0));
289+
assertThat((Double) values.get(2).get(1), equalTo(0.0));
290290
}
291291
}
292292

@@ -308,11 +308,11 @@ public void testDisjunctionScoringMultipleNonPushableFunctions() {
308308
assertThat(values.get(1).get(0), equalTo(6));
309309

310310
// Matches the full text query and a two pushable query
311-
assertThat((Double) values.get(0).get(1), greaterThan(2.0));
312-
assertThat((Double) values.get(0).get(1), lessThan(3.0));
311+
assertThat((Double) values.get(0).get(1), greaterThan(1.0));
312+
assertThat((Double) values.get(0).get(1), lessThan(2.0));
313313
// Matches just the match function
314-
assertThat((Double) values.get(1).get(1), lessThan(2.0));
315-
assertThat((Double) values.get(1).get(1), greaterThan(1.0));
314+
assertThat((Double) values.get(1).get(1), lessThan(1.0));
315+
assertThat((Double) values.get(1).get(1), greaterThan(0.0));
316316
}
317317
}
318318

@@ -334,10 +334,10 @@ public void testDisjunctionScoringWithNot() {
334334
assertThat(values.get(1).get(0), equalTo(4));
335335
assertThat(values.get(2).get(0), equalTo(5));
336336

337-
// Matches NOT gets 0.0 and default score is 1.0
338-
assertThat((Double) values.get(0).get(1), equalTo(1.0));
339-
assertThat((Double) values.get(1).get(1), equalTo(1.0));
340-
assertThat((Double) values.get(2).get(1), equalTo(1.0));
337+
// Matches NOT gets 0.0
338+
assertThat((Double) values.get(0).get(1), equalTo(0.0));
339+
assertThat((Double) values.get(1).get(1), equalTo(0.0));
340+
assertThat((Double) values.get(2).get(1), equalTo(0.0));
341341
}
342342
}
343343

@@ -357,8 +357,8 @@ public void testScoringWithNoFullTextFunction() {
357357

358358
assertThat(values.get(0).get(0), equalTo(4));
359359

360-
// Non pushable query gets score of 0.0, summed with 1.0 coming from Lucene
361-
assertThat((Double) values.get(0).get(1), equalTo(1.0));
360+
// Non pushable query gets score of 0.0
361+
assertThat((Double) values.get(0).get(1), equalTo(0.0));
362362
}
363363
}
364364

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141

4242
import java.util.List;
4343

44-
import static org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator.NO_MATCH_SCORE;
45-
4644
public final class EvalMapper {
4745

4846
private static final List<ExpressionMapper<?>> MAPPERS = List.of(
@@ -192,13 +190,7 @@ public DoubleBlock score(Page page, BlockFactory blockFactory) {
192190
for (int p = 0; p < positionCount; p++) {
193191
double l = lhs.getDouble(p);
194192
double r = rhs.getDouble(p);
195-
if (l == NO_MATCH_SCORE) {
196-
result.appendDouble(p, r);
197-
} else if (r == NO_MATCH_SCORE) {
198-
result.appendDouble(p, l);
199-
} else {
200-
result.appendDouble(p, l + r);
201-
}
193+
result.appendDouble(bl.function().score(l, r));
202194
}
203195
return result.build().asBlock();
204196
}
@@ -239,7 +231,7 @@ public DoubleBlock score(Page page, BlockFactory blockFactory) {
239231
DoubleVector.Builder result = blockFactory.newDoubleVectorFixedBuilder(page.getPositionCount());
240232
// TODO We could optimize for constant vectors
241233
for (int i = 0; i < scoreVector.getPositionCount(); i++) {
242-
result.appendDouble(scoreVector.getDouble(i) == NO_MATCH_SCORE ? 0.0 : NO_MATCH_SCORE);
234+
result.appendDouble(scoreVector.getDouble(i) == NO_MATCH_SCORE ? MATCH_SCORE : NO_MATCH_SCORE);
243235
}
244236
return result.build().asBlock();
245237
}
@@ -408,6 +400,22 @@ public Block eval(Page page) {
408400
}
409401
}
410402

403+
@Override
404+
public DoubleBlock score(Page page, BlockFactory blockFactory) {
405+
// TODO We could skip re-evaluating the field if we store the result of eval()
406+
try (Block fieldBlock = field.eval(page)) {
407+
if (fieldBlock.asVector() != null) {
408+
return driverContext.blockFactory().newConstantDoubleBlockWith(NO_MATCH_SCORE, page.getPositionCount());
409+
}
410+
try (var builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(page.getPositionCount())) {
411+
for (int p = 0; p < page.getPositionCount(); p++) {
412+
builder.appendDouble(p, fieldBlock.isNull(p) ? MATCH_SCORE : NO_MATCH_SCORE);
413+
}
414+
return builder.build().asBlock();
415+
}
416+
}
417+
}
418+
411419
@Override
412420
public void close() {
413421
Releasables.closeExpectNoException(field);
@@ -463,6 +471,22 @@ public Block eval(Page page) {
463471
}
464472
}
465473

474+
@Override
475+
public DoubleBlock score(Page page, BlockFactory blockFactory) {
476+
// TODO We could skip re-evaluating the field if we store the result of eval()
477+
try (Block fieldBlock = field.eval(page)) {
478+
if (fieldBlock.asVector() != null) {
479+
return driverContext.blockFactory().newConstantDoubleBlockWith(MATCH_SCORE, page.getPositionCount());
480+
}
481+
try (var builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(page.getPositionCount())) {
482+
for (int p = 0; p < page.getPositionCount(); p++) {
483+
builder.appendDouble(p, fieldBlock.isNull(p) ? NO_MATCH_SCORE : MATCH_SCORE);
484+
}
485+
return builder.build().asBlock();
486+
}
487+
}
488+
}
489+
466490
@Override
467491
public void close() {
468492
Releasables.closeExpectNoException(field);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/logical/BinaryLogicOperation.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import java.util.function.BiFunction;
1313

14+
import static org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator.NO_MATCH_SCORE;
15+
1416
public enum BinaryLogicOperation implements PredicateBiFunction<Boolean, Boolean, Boolean> {
1517

1618
AND((l, r) -> {
@@ -21,6 +23,11 @@ public enum BinaryLogicOperation implements PredicateBiFunction<Boolean, Boolean
2123
return null;
2224
}
2325
return Boolean.logicalAnd(l.booleanValue(), r.booleanValue());
26+
}, (leftScore, rightScore) -> {
27+
if (NO_MATCH_SCORE == leftScore || NO_MATCH_SCORE == rightScore) {
28+
return NO_MATCH_SCORE;
29+
}
30+
return leftScore + rightScore;
2431
}, "AND"),
2532
OR((l, r) -> {
2633
if (Boolean.TRUE.equals(l) || Boolean.TRUE.equals(r)) {
@@ -30,13 +37,20 @@ public enum BinaryLogicOperation implements PredicateBiFunction<Boolean, Boolean
3037
return null;
3138
}
3239
return Boolean.logicalOr(l.booleanValue(), r.booleanValue());
40+
}, (leftScore, rightScore) -> {
41+
if (NO_MATCH_SCORE == leftScore || NO_MATCH_SCORE == rightScore) {
42+
return Math.max(leftScore, rightScore);
43+
}
44+
return leftScore + rightScore;
3345
}, "OR");
3446

3547
private final BiFunction<Boolean, Boolean, Boolean> process;
48+
private final BiFunction<Double, Double, Double> scoreFunction;
3649
private final String symbol;
3750

38-
BinaryLogicOperation(BiFunction<Boolean, Boolean, Boolean> process, String symbol) {
51+
BinaryLogicOperation(BiFunction<Boolean, Boolean, Boolean> process, BiFunction<Double, Double, Double> scoreFunction, String symbol) {
3952
this.process = process;
53+
this.scoreFunction = scoreFunction;
4054
this.symbol = symbol;
4155
}
4256

@@ -50,6 +64,10 @@ public Boolean apply(Boolean left, Boolean right) {
5064
return process.apply(left, right);
5165
}
5266

67+
public Double score(Double leftScore, Double rightScore) {
68+
return scoreFunction.apply(leftScore, rightScore);
69+
}
70+
5371
@Override
5472
public final Boolean doApply(Boolean left, Boolean right) {
5573
return null;

0 commit comments

Comments
 (0)