Skip to content

Commit e98ad43

Browse files
committed
Forbid usage of _score aggregations on STATS when there is a WHERE clause
1 parent f9bf170 commit e98ad43

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/scoring.csv-spec

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,10 @@ required_capability: match_function
551551
required_capability: full_text_functions_in_stats_where
552552

553553
from books metadata _score
554-
| stats avg_score = avg(_score) where match(title, "Lord Rings", {"operator": "AND"})
554+
| where match(title, "Lord Rings", {"operator": "AND"})
555+
| stats avg_score = avg(_score), max_score = max(_score), min_score = min(_score)
555556
;
556557

557-
avg_score:double
558-
3.869828939437866
558+
avg_score:double | max_score:double | min_score:double
559+
3.869828939437866 | 5.123856544494629 | 3.0124807357788086
559560
;

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/MatchFunctionIT.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
import org.elasticsearch.common.settings.Settings;
1414
import org.elasticsearch.xpack.esql.VerificationException;
1515
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
16+
import org.hamcrest.Matchers;
1617
import org.junit.Before;
1718

1819
import java.util.List;
1920

2021
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
22+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList;
2123
import static org.hamcrest.CoreMatchers.containsString;
2224

2325
//@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
@@ -265,6 +267,21 @@ public void testMatchWithStats() {
265267
assertColumnTypes(resp.columns(), List.of("long", "long"));
266268
assertValues(resp.values(), List.of(List.of(2L, 4L)));
267269
}
270+
271+
query = """
272+
FROM test METADATA _score
273+
| WHERE match(content, "fox")
274+
| STATS m = max(_score), n = min(_score)
275+
""";
276+
277+
try (var resp = run(query)) {
278+
assertColumnNames(resp.columns(), List.of("m", "n"));
279+
assertColumnTypes(resp.columns(), List.of("double", "double"));
280+
List<List<Object>> valuesList = getValuesList(resp.values());
281+
assertEquals(1, valuesList.size());
282+
assertThat((double) valuesList.get(0).get(0), Matchers.greaterThan(1.0));
283+
assertThat((double) valuesList.get(0).get(1), Matchers.greaterThan(0.0));
284+
}
268285
}
269286

270287
public void testMatchWithinEval() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222
import org.elasticsearch.xpack.esql.core.expression.Expression;
2323
import org.elasticsearch.xpack.esql.core.expression.Expressions;
2424
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
25+
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
2526
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
2627
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
2728
import org.elasticsearch.xpack.esql.core.tree.Source;
29+
import org.elasticsearch.xpack.esql.core.util.Holder;
2830
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
2931
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
3032
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
3133
import org.elasticsearch.xpack.esql.expression.function.aggregate.TimeSeriesAggregateFunction;
34+
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
3235
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
3336
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
3437
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
@@ -229,7 +232,22 @@ public void postAnalysisVerification(Failures failures) {
229232
);
230233
}
231234
checkCategorizeGrouping(failures);
235+
checkMultipleScoreAggregations(failures);
236+
}
232237

238+
private void checkMultipleScoreAggregations(Failures failures) {
239+
Holder<Boolean> hasScoringAggs = new Holder<>();
240+
forEachExpression(FilteredExpression.class, fe -> {
241+
if (fe.delegate() instanceof AggregateFunction aggregateFunction) {
242+
if (aggregateFunction.field() instanceof MetadataAttribute metadataAttribute) {
243+
if (MetadataAttribute.SCORE.equals(metadataAttribute.name())) {
244+
if (fe.filter().anyMatch(e -> e instanceof FullTextFunction)) {
245+
failures.add(fail(fe, "cannot use _score aggregations with a WHERE filter in a STATS command"));
246+
}
247+
}
248+
}
249+
}
250+
});
233251
}
234252

235253
/**

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2493,6 +2493,27 @@ public void testInsistNotOnTopOfFrom() {
24932493
);
24942494
}
24952495

2496+
public void testFullTextFunctionsInStats() {
2497+
checkFullTextFunctionsInStats("match(last_name, \"Smith\")");
2498+
checkFullTextFunctionsInStats("multi_match(\"Smith\", first_name, last_name)");
2499+
checkFullTextFunctionsInStats("last_name : \"Smith\"");
2500+
checkFullTextFunctionsInStats("qstr(\"last_name: Smith\")");
2501+
checkFullTextFunctionsInStats("kql(\"last_name: Smith\")");
2502+
}
2503+
2504+
private void checkFullTextFunctionsInStats(String functionInvocation) {
2505+
2506+
query("from test | stats c = max(salary) where " + functionInvocation);
2507+
query("from test | stats c = max(salary) where " + functionInvocation + " or length(first_name) > 10");
2508+
query("from test metadata _score | where " + functionInvocation + " | stats c = max(_score)");
2509+
query("from test metadata _score | where " + functionInvocation + " or length(first_name) > 10 | stats c = max(_score)");
2510+
2511+
assertThat(
2512+
error("from test metadata _score | stats c = max(_score) where " + functionInvocation),
2513+
containsString("cannot use _score aggregations with a WHERE filter in a STATS command")
2514+
);
2515+
}
2516+
24962517
private void query(String query) {
24972518
query(query, defaultAnalyzer);
24982519
}

0 commit comments

Comments
 (0)