88package org .elasticsearch .compute .operator ;
99
1010import org .elasticsearch .compute .data .Block ;
11+ import org .elasticsearch .compute .data .BlockFactory ;
1112import org .elasticsearch .compute .data .BooleanBlock ;
13+ import org .elasticsearch .compute .data .DoubleBlock ;
14+ import org .elasticsearch .compute .data .DoubleVector ;
1215import org .elasticsearch .compute .data .Page ;
1316import org .elasticsearch .compute .operator .EvalOperator .ExpressionEvaluator ;
1417import org .elasticsearch .core .Releasables ;
1518
1619import java .util .Arrays ;
1720
21+ import static org .elasticsearch .compute .lucene .LuceneQueryExpressionEvaluator .SCORE_FOR_FALSE ;
22+
1823public class FilterOperator extends AbstractPageMappingOperator {
1924
25+ public static final int SCORE_BLOCK_INDEX = 1 ;
26+
2027 private final EvalOperator .ExpressionEvaluator evaluator ;
28+ private final boolean usesScoring ;
29+ private final BlockFactory blockFactory ;
2130
22- public record FilterOperatorFactory (ExpressionEvaluator .Factory evaluatorSupplier ) implements OperatorFactory {
31+ public record FilterOperatorFactory (ExpressionEvaluator .Factory evaluatorSupplier , boolean usesScoring ) implements OperatorFactory {
2332
2433 @ Override
2534 public Operator get (DriverContext driverContext ) {
26- return new FilterOperator (evaluatorSupplier .get (driverContext ));
35+ return new FilterOperator (evaluatorSupplier .get (driverContext ), usesScoring , driverContext . blockFactory () );
2736 }
2837
2938 @ Override
@@ -32,30 +41,46 @@ public String describe() {
3241 }
3342 }
3443
35- public FilterOperator (EvalOperator . ExpressionEvaluator evaluator ) {
44+ public FilterOperator (ExpressionEvaluator evaluator , boolean usesScoring , BlockFactory blockFactory ) {
3645 this .evaluator = evaluator ;
46+ this .usesScoring = usesScoring ;
47+ this .blockFactory = blockFactory ;
3748 }
3849
3950 @ Override
4051 protected Page process (Page page ) {
4152 int rowCount = 0 ;
4253 int [] positions = new int [page .getPositionCount ()];
4354
44- try (BooleanBlock test = ( BooleanBlock ) evaluator .eval (page )) {
45- if (test .areAllValuesNull ()) {
55+ try (Block filterResultBlock = evaluator .eval (page )) {
56+ if (filterResultBlock .areAllValuesNull ()) {
4657 // All results are null which is like false. No values selected.
4758 page .releaseBlocks ();
4859 return null ;
4960 }
61+
62+ // Explicit types to avoid casting on every element
63+ DoubleBlock scoreBlock = null ;
64+ BooleanBlock testBlock = null ;
65+ if (usesScoring ) {
66+ assert filterResultBlock instanceof DoubleBlock : "Evaluated block should be a DoubleBlock when using scoring" ;
67+ scoreBlock = (DoubleBlock ) filterResultBlock ;
68+ } else {
69+ assert filterResultBlock instanceof BooleanBlock : "Evaluated block should be a BooleanBlock when not using scoring" ;
70+ testBlock = (BooleanBlock ) filterResultBlock ;
71+ }
72+
5073 // TODO we can detect constant true or false from the type
5174 // TODO or we could make a new method in bool-valued evaluators that returns a list of numbers
5275 for (int p = 0 ; p < page .getPositionCount (); p ++) {
53- if (test .isNull (p ) || test .getValueCount (p ) != 1 ) {
76+ if (filterResultBlock .isNull (p ) || filterResultBlock .getValueCount (p ) != 1 ) {
5477 // Null is like false
5578 // And, for now, multivalued results are like false too
5679 continue ;
5780 }
58- if (test .getBoolean (test .getFirstValueIndex (p ))) {
81+ if (usesScoring && scoreBlock .getDouble (scoreBlock .getFirstValueIndex (p )) != SCORE_FOR_FALSE ) {
82+ positions [rowCount ++] = p ;
83+ } else if (usesScoring == false && testBlock .getBoolean (testBlock .getFirstValueIndex (p ))) {
5984 positions [rowCount ++] = p ;
6085 }
6186 }
@@ -73,7 +98,11 @@ protected Page process(Page page) {
7398 boolean success = false ;
7499 try {
75100 for (int i = 0 ; i < page .getBlockCount (); i ++) {
76- filteredBlocks [i ] = page .getBlock (i ).filter (positions );
101+ if (usesScoring && i == SCORE_BLOCK_INDEX ) {
102+ filteredBlocks [i ] = createScoresBlock (rowCount , (DoubleBlock ) filterResultBlock , positions );
103+ } else {
104+ filteredBlocks [i ] = page .getBlock (i ).filter (positions );
105+ }
77106 }
78107 success = true ;
79108 } finally {
@@ -86,6 +115,15 @@ protected Page process(Page page) {
86115 }
87116 }
88117
118+ private Block createScoresBlock (int rowCount , DoubleBlock scoreBlock , int [] positions ) {
119+ // Create a new scores block with the retrieved scores, that will replace the existing one on the result page
120+ DoubleVector .Builder updatedScoresBuilder = blockFactory .newDoubleVectorBuilder (rowCount );
121+ for (int j = 0 ; j < rowCount ; j ++) {
122+ updatedScoresBuilder .appendDouble (scoreBlock .getDouble (positions [j ]));
123+ }
124+ return updatedScoresBuilder .build ().asBlock ();
125+ }
126+
89127 @ Override
90128 public String toString () {
91129 return "FilterOperator[" + "evaluator=" + evaluator + ']' ;
0 commit comments