Skip to content

Commit e4eb86d

Browse files
committed
Implement ExpressionScoreMapper for FullTextFunction and BinaryLogic
1 parent 9ca756a commit e4eb86d

File tree

2 files changed

+56
-71
lines changed

2 files changed

+56
-71
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 20 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,24 @@
1010
import org.elasticsearch.common.lucene.BytesRefs;
1111
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator;
1212
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator.ShardConfig;
13+
import org.elasticsearch.compute.lucene.LuceneQueryScoreEvaluator;
1314
import org.elasticsearch.compute.operator.EvalOperator;
15+
import org.elasticsearch.compute.operator.ScoreOperator;
1416
import org.elasticsearch.index.query.QueryBuilder;
1517
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
1618
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
1719
import org.elasticsearch.xpack.esql.common.Failures;
1820
import org.elasticsearch.xpack.esql.core.expression.Expression;
1921
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
20-
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
2122
import org.elasticsearch.xpack.esql.core.expression.Nullability;
2223
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
2324
import org.elasticsearch.xpack.esql.core.expression.function.Function;
2425
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
2526
import org.elasticsearch.xpack.esql.core.tree.Source;
2627
import org.elasticsearch.xpack.esql.core.type.DataType;
27-
import org.elasticsearch.xpack.esql.core.util.Holder;
2828
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
2929
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
3030
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
31-
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
3231
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3332
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
3433
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@@ -39,6 +38,7 @@
3938
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
4039
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
4140
import org.elasticsearch.xpack.esql.querydsl.query.TranslationAwareExpressionQuery;
41+
import org.elasticsearch.xpack.esql.score.ExpressionScoreMapper;
4242

4343
import java.util.List;
4444
import java.util.Locale;
@@ -56,7 +56,12 @@
5656
* These functions needs to be pushed down to Lucene queries to be executed - there's no Evaluator for them, but depend on
5757
* {@link org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer} to rewrite them into Lucene queries.
5858
*/
59-
public abstract class FullTextFunction extends Function implements TranslationAware, PostAnalysisPlanVerificationAware, EvaluatorMapper {
59+
public abstract class FullTextFunction extends Function
60+
implements
61+
TranslationAware,
62+
PostAnalysisPlanVerificationAware,
63+
EvaluatorMapper,
64+
ExpressionScoreMapper {
6065

6166
private final Expression query;
6267
private final QueryBuilder queryBuilder;
@@ -204,79 +209,13 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failu
204209
failures
205210
);
206211
checkFullTextFunctionsParents(condition, failures);
207-
208-
boolean usesScore = plan.output()
209-
.stream()
210-
.anyMatch(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE));
211-
if (usesScore) {
212-
checkFullTextSearchDisjunctions(condition, failures);
213-
}
214212
} else {
215213
plan.forEachExpression(FullTextFunction.class, ftf -> {
216214
failures.add(fail(ftf, "[{}] {} is only supported in WHERE commands", ftf.functionName(), ftf.functionType()));
217215
});
218216
}
219217
}
220218

221-
/**
222-
* Checks whether a condition contains a disjunction with a full text search.
223-
* If it does, check that every element of the disjunction is a full text search or combinations (AND, OR, NOT) of them.
224-
* If not, add a failure to the failures collection.
225-
*
226-
* @param condition condition to check for disjunctions of full text searches
227-
* @param failures failures collection to add to
228-
*/
229-
private static void checkFullTextSearchDisjunctions(Expression condition, Failures failures) {
230-
Holder<Boolean> isInvalid = new Holder<>(false);
231-
condition.forEachDown(Or.class, or -> {
232-
if (isInvalid.get()) {
233-
// Exit early if we already have a failures
234-
return;
235-
}
236-
if (checkDisjunctionPushable(or) == false) {
237-
isInvalid.set(true);
238-
failures.add(
239-
fail(
240-
or,
241-
"Invalid condition when using METADATA _score [{}]. Full text functions can be used in an OR condition, "
242-
+ "but only if just full text functions are used in the OR condition",
243-
or.sourceText()
244-
)
245-
);
246-
}
247-
});
248-
}
249-
250-
/**
251-
* Checks if a disjunction is pushable from the point of view of FullTextFunctions. Either it has no FullTextFunctions or
252-
* all it contains are FullTextFunctions.
253-
*
254-
* @param or disjunction to check
255-
* @return true if the disjunction is pushable, false otherwise
256-
*/
257-
private static boolean checkDisjunctionPushable(Or or) {
258-
boolean hasFullText = or.anyMatch(FullTextFunction.class::isInstance);
259-
return hasFullText == false || onlyFullTextFunctionsInExpression(or);
260-
}
261-
262-
/**
263-
* Checks whether an expression contains just full text functions or negations (NOT) and combinations (AND, OR) of full text functions
264-
*
265-
* @param expression expression to check
266-
* @return true if all children are full text functions or negations of full text functions, false otherwise
267-
*/
268-
private static boolean onlyFullTextFunctionsInExpression(Expression expression) {
269-
if (expression instanceof FullTextFunction) {
270-
return true;
271-
} else if (expression instanceof Not) {
272-
return onlyFullTextFunctionsInExpression(expression.children().get(0));
273-
} else if (expression instanceof BinaryLogic binaryLogic) {
274-
return onlyFullTextFunctionsInExpression(binaryLogic.left()) && onlyFullTextFunctionsInExpression(binaryLogic.right());
275-
}
276-
277-
return false;
278-
}
279-
280219
/**
281220
* Checks all commands that exist before a specific type satisfy conditions.
282221
*
@@ -365,4 +304,15 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua
365304
}
366305
return new LuceneQueryExpressionEvaluator.Factory(shardConfigs);
367306
}
307+
308+
@Override
309+
public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
310+
List<EsPhysicalOperationProviders.ShardContext> shardContexts = toScorer.shardContexts();
311+
ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
312+
int i = 0;
313+
for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
314+
shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
315+
}
316+
return new LuceneQueryScoreEvaluator.Factory(shardConfigs);
317+
}
368318
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/logical/BinaryLogic.java

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
import org.elasticsearch.common.io.stream.StreamInput;
1010
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.compute.data.DoubleBlock;
12+
import org.elasticsearch.compute.data.DoubleVector;
13+
import org.elasticsearch.compute.data.Page;
14+
import org.elasticsearch.compute.operator.DriverContext;
15+
import org.elasticsearch.compute.operator.ScoreOperator;
1116
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
1217
import org.elasticsearch.xpack.esql.core.expression.Expression;
1318
import org.elasticsearch.xpack.esql.core.expression.Nullability;
@@ -22,14 +27,18 @@
2227
import org.elasticsearch.xpack.esql.core.util.PlanStreamInput;
2328
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
2429
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
30+
import org.elasticsearch.xpack.esql.score.ExpressionScoreMapper;
2531

2632
import java.io.IOException;
2733
import java.util.Arrays;
2834
import java.util.List;
2935

3036
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isBoolean;
3137

32-
public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boolean, BinaryLogicOperation> implements TranslationAware {
38+
public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boolean, BinaryLogicOperation>
39+
implements
40+
TranslationAware,
41+
ExpressionScoreMapper {
3342

3443
protected BinaryLogic(Source source, Expression left, Expression right, BinaryLogicOperation operation) {
3544
super(source, left, right, operation);
@@ -108,4 +117,30 @@ public static Query boolQuery(Source source, Query left, Query right, boolean is
108117
}
109118
return new BoolQuery(source, isAnd, queries);
110119
}
120+
121+
@Override
122+
public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
123+
return context -> new BinaryLogicScorer(context, toScorer.toScorer(left()).get(context), toScorer.toScorer(right()).get(context));
124+
}
125+
126+
private record BinaryLogicScorer(DriverContext driverContext, ScoreOperator.ExpressionScorer left, ScoreOperator.ExpressionScorer right)
127+
implements
128+
ScoreOperator.ExpressionScorer {
129+
@Override
130+
public DoubleBlock score(Page page) {
131+
DoubleVector.Builder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(page.getPositionCount());
132+
try (DoubleVector leftVector = left.score(page).asVector(); DoubleVector rightVector = right.score(page).asVector()) {
133+
for (int i = 0; i < page.getPositionCount(); i++) {
134+
builder.appendDouble(leftVector.getDouble(i) + rightVector.getDouble(i));
135+
}
136+
}
137+
return builder.build().asBlock();
138+
}
139+
140+
@Override
141+
public void close() {
142+
left.close();
143+
right.close();
144+
}
145+
}
111146
}

0 commit comments

Comments
 (0)