Skip to content

Commit 967a415

Browse files
committed
KNN can't be used in stats functions
1 parent a6562f8 commit 967a415

File tree

4 files changed

+79
-26
lines changed

4 files changed

+79
-26
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ public QueryBuilder queryBuilder() {
181181

182182
@Override
183183
public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
184-
return FullTextFunction::checkFullTextQueryFunctions;
184+
return this::checkFullTextQueryFunctions;
185185
}
186186

187187
/**
@@ -190,7 +190,7 @@ public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
190190
* @param plan root plan to check
191191
* @param failures failures found
192192
*/
193-
private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failures) {
193+
private void checkFullTextQueryFunctions(LogicalPlan plan, Failures failures) {
194194
if (plan instanceof Filter f) {
195195
Expression condition = f.condition();
196196

@@ -219,23 +219,25 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failu
219219
checkFullTextFunctionsInAggs(agg, failures);
220220
} else {
221221
plan.forEachExpression(FullTextFunction.class, ftf -> {
222-
failures.add(fail(ftf, "[{}] {} is only supported in WHERE and STATS commands", ftf.functionName(), ftf.functionType()));
222+
failures.add(fail(ftf, notSupportedErroMessage(), ftf.functionName(), ftf.functionType()));
223223
});
224224
}
225225
}
226226

227-
private static void checkFullTextFunctionsInAggs(Aggregate agg, Failures failures) {
227+
protected void checkFullTextFunctionsInAggs(Aggregate agg, Failures failures) {
228228
agg.groupings().forEach(exp -> {
229229
exp.forEachDown(e -> {
230230
if (e instanceof FullTextFunction ftf) {
231-
failures.add(
232-
fail(ftf, "[{}] {} is only supported in WHERE and STATS commands", ftf.functionName(), ftf.functionType())
233-
);
231+
failures.add(fail(ftf, notSupportedErroMessage(), ftf.functionName(), ftf.functionType()));
234232
}
235233
});
236234
});
237235
}
238236

237+
protected String notSupportedErroMessage() {
238+
return "[{}] {} is only supported in WHERE and STATS commands";
239+
}
240+
239241
/**
240242
* Checks all commands that exist before a specific type satisfy conditions.
241243
*

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.index.query.QueryBuilder;
14+
import org.elasticsearch.xpack.esql.common.Failures;
1415
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
1516
import org.elasticsearch.xpack.esql.core.expression.Expression;
1617
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
@@ -28,9 +29,11 @@
2829
import org.elasticsearch.xpack.esql.expression.function.MapParam;
2930
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
3031
import org.elasticsearch.xpack.esql.expression.function.Param;
32+
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
3133
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
3234
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
3335
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
36+
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
3437
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3538
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
3639

@@ -45,6 +48,7 @@
4548
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
4649
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
4750
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
51+
import static org.elasticsearch.xpack.esql.common.Failure.fail;
4852
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
4953
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
5054
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
@@ -264,6 +268,26 @@ public Expression replaceLimit(int limit) {
264268
);
265269
}
266270

271+
/**
272+
* KNN should not be used in aggregations, as it is a top-N query and not a filtering query
273+
*/
274+
@Override
275+
protected void checkFullTextFunctionsInAggs(Aggregate agg, Failures failures) {
276+
super.checkFullTextFunctionsInAggs(agg, failures);
277+
agg.aggregates().forEach(exp -> {
278+
exp.forEachDown(FilteredExpression.class, filterExp -> {
279+
filterExp.filter().forEachDown(Knn.class, knn -> {
280+
failures.add(fail(knn, notSupportedErroMessage(), knn.functionName(), knn.functionType()));
281+
});
282+
});
283+
});
284+
}
285+
286+
@Override
287+
protected String notSupportedErroMessage() {
288+
return "[{}] {} is only supported in WHERE commands";
289+
}
290+
267291
@Override
268292
protected NodeInfo<? extends Expression> info() {
269293
return NodeInfo.create(this, Knn::new, field(), query(), options());

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,40 +1357,45 @@ private void checkNonFieldBasedFullTextFunctionsNotAllowedAfterCommands(String f
13571357
}
13581358

13591359
public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception {
1360-
checkFullTextFunctionsOnlyAllowedInWhere("MATCH", "match(title, \"Meditation\")", "function");
1361-
checkFullTextFunctionsOnlyAllowedInWhere(":", "title:\"Meditation\"", "operator");
1362-
checkFullTextFunctionsOnlyAllowedInWhere("QSTR", "qstr(\"Meditation\")", "function");
1363-
checkFullTextFunctionsOnlyAllowedInWhere("KQL", "kql(\"Meditation\")", "function");
1364-
checkFullTextFunctionsOnlyAllowedInWhere("MatchPhrase", "match_phrase(title, \"Meditation\")", "function");
1360+
String defaultErrorMsg = "is only supported in WHERE and STATS commands";
1361+
checkFullTextFunctionsOnlyAllowedInWhere("MATCH", "match(title, \"Meditation\")", "function", defaultErrorMsg);
1362+
checkFullTextFunctionsOnlyAllowedInWhere(":", "title:\"Meditation\"", "operator", defaultErrorMsg);
1363+
checkFullTextFunctionsOnlyAllowedInWhere("QSTR", "qstr(\"Meditation\")", "function", defaultErrorMsg);
1364+
checkFullTextFunctionsOnlyAllowedInWhere("KQL", "kql(\"Meditation\")", "function", defaultErrorMsg);
1365+
checkFullTextFunctionsOnlyAllowedInWhere("MatchPhrase", "match_phrase(title, \"Meditation\")", "function", defaultErrorMsg);
13651366
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
1366-
checkFullTextFunctionsOnlyAllowedInWhere("Term", "term(title, \"Meditation\")", "function");
1367+
checkFullTextFunctionsOnlyAllowedInWhere("Term", "term(title, \"Meditation\")", "function", defaultErrorMsg);
13671368
}
13681369
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
1369-
checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
1370+
checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function", defaultErrorMsg);
13701371
}
13711372
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
1372-
checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function");
1373+
checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function", "is only supported in WHERE commands");
13731374
}
13741375
}
13751376

1376-
private void checkFullTextFunctionsOnlyAllowedInWhere(String functionName, String functionInvocation, String functionType)
1377-
throws Exception {
1377+
private void checkFullTextFunctionsOnlyAllowedInWhere(
1378+
String functionName,
1379+
String functionInvocation,
1380+
String functionType,
1381+
String errorMsg
1382+
) throws Exception {
13781383
assertThat(
13791384
error("from test | eval y = " + functionInvocation, fullTextAnalyzer),
1380-
containsString("[" + functionName + "] " + functionType + " is only supported in WHERE and STATS commands")
1385+
containsString("[" + functionName + "] " + functionType + " " + errorMsg)
13811386
);
13821387
assertThat(
13831388
error("from test | sort " + functionInvocation + " asc", fullTextAnalyzer),
1384-
containsString("[" + functionName + "] " + functionType + " is only supported in WHERE and STATS commands")
1389+
containsString("[" + functionName + "] " + functionType + " " + errorMsg)
13851390
);
13861391
assertThat(
13871392
error("from test | stats max_id = max(id) by " + functionInvocation, fullTextAnalyzer),
1388-
containsString("[" + functionName + "] " + functionType + " is only supported in WHERE and STATS commands")
1393+
containsString("[" + functionName + "] " + functionType + " " + errorMsg)
13891394
);
13901395
if ("KQL".equals(functionName) || "QSTR".equals(functionName)) {
13911396
assertThat(
13921397
error("row a = " + functionInvocation, fullTextAnalyzer),
1393-
containsString("[" + functionName + "] " + functionType + " is only supported in WHERE and STATS commands")
1398+
containsString("[" + functionName + "] " + functionType + " " + errorMsg)
13941399
);
13951400
}
13961401
}
@@ -2215,7 +2220,7 @@ public void testFullTextFunctionsInStats() {
22152220
checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
22162221
}
22172222
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
2218-
checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])");
2223+
checkFullTextFunctionsInStatsError("knn(vector, [0, 1, 2])");
22192224
}
22202225
}
22212226

@@ -2234,6 +2239,25 @@ private void checkFullTextFunctionsInStats(String functionInvocation) {
22342239
);
22352240
}
22362241

2242+
private void checkFullTextFunctionsInStatsError(String functionInvocation) {
2243+
assertThat(
2244+
error("from test metadata _score | stats c = max(_score) where " + functionInvocation, fullTextAnalyzer),
2245+
containsString("cannot use _score aggregations with a WHERE filter in a STATS command")
2246+
);
2247+
assertThat(
2248+
error("from test | stats c = max(id) where " + functionInvocation, fullTextAnalyzer),
2249+
containsString("is only supported in WHERE commands")
2250+
);
2251+
assertThat(
2252+
error("from test | stats c = max(id) where " + functionInvocation + " or length(title) > 10", fullTextAnalyzer),
2253+
containsString("is only supported in WHERE commands")
2254+
);
2255+
assertThat(
2256+
error("from test metadata _score | stats c = max(_score) where " + functionInvocation, fullTextAnalyzer),
2257+
containsString("is only supported in WHERE commands")
2258+
);
2259+
}
2260+
22372261
private void query(String query) {
22382262
query(query, defaultAnalyzer);
22392263
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
6262
import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator;
6363
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
64+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
6465
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
6566
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
6667
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
@@ -1475,10 +1476,12 @@ private void testFullTextFunction(FullTextFunctionTestCase testCase) {
14751476
testMultipleFullTextFunctionFilterPushdown(testCase);
14761477
testFullTextFunctionsDisjunctionPushdown(testCase);
14771478
testFullTextFunctionsDisjunctionWithFiltersPushdown(testCase);
1478-
testFullTextFunctionWithStatsWherePushable(testCase);
1479-
testFullTextFunctionWithStatsPushableAndNonPushableCondition(testCase);
1480-
testFullTextFunctionStatsWithNonPushableCondition(testCase);
1481-
testFullTextFunctionWithStatsBy(testCase);
1479+
if (testCase.fullTextFunction != Knn.class) {
1480+
testFullTextFunctionWithStatsWherePushable(testCase);
1481+
testFullTextFunctionWithStatsPushableAndNonPushableCondition(testCase);
1482+
testFullTextFunctionStatsWithNonPushableCondition(testCase);
1483+
testFullTextFunctionWithStatsBy(testCase);
1484+
}
14821485
}
14831486

14841487
private void testBasicFullTextFunction(FullTextFunctionTestCase testCase) {

0 commit comments

Comments
 (0)