diff --git a/docs/changelog/124446.yaml b/docs/changelog/124446.yaml new file mode 100644 index 0000000000000..c330dbf682780 --- /dev/null +++ b/docs/changelog/124446.yaml @@ -0,0 +1,6 @@ +pr: 124446 +summary: "ESQL: Fail in `AggregateFunction` when `LogicalPlan` is not an `Aggregate`" +area: ES|QL +type: bug +issues: + - 124311 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 8aa7f697489c6..e5ddc83c4cf09 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -19,8 +19,8 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import java.io.IOException; import java.util.List; @@ -139,14 +139,14 @@ public boolean equals(Object obj) { @Override public BiConsumer postAnalysisPlanVerification() { return (p, failures) -> { - if (p instanceof OrderBy order) { - order.order().forEach(o -> { - o.forEachDown(Function.class, f -> { - if (f instanceof AggregateFunction) { - failures.add(fail(f, "Aggregate functions are not allowed in SORT [{}]", f.functionName())); - } - }); - }); + if ((p instanceof Aggregate) == false) { + p.expressions().forEach(x -> x.forEachDown(AggregateFunction.class, af -> { + if (af instanceof Rate) { + failures.add(fail(af, "aggregate function [{}] not allowed outside METRICS command", af.sourceText())); + } else { + failures.add(fail(af, "aggregate 
function [{}] not allowed outside STATS command", af.sourceText())); + } + })); } }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java index af81e26d57c60..e85c63e79a7e9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java @@ -24,8 +24,6 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.GeneratingPlan; @@ -179,14 +177,6 @@ public void postAnalysisVerification(Failures failures) { ) ); } - // check no aggregate functions are used - field.forEachDown(AggregateFunction.class, af -> { - if (af instanceof Rate) { - failures.add(fail(af, "aggregate function [{}] not allowed outside METRICS command", af.sourceText())); - } else { - failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText())); - } - }); }); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 12ad010a8e9c6..6a79ae4c40792 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1005,6 +1005,17 @@ public void testNotFoundFieldInNestedFunction() { line 1:23: Unknown column [avg]""", error("from test | stats c = avg by missing + 1, 
not_found")); } + public void testMultipleAggsOutsideStats() { + assertEquals( + """ + 1:71: aggregate function [avg(salary)] not allowed outside STATS command + line 1:96: aggregate function [median(emp_no)] not allowed outside STATS command + line 1:22: aggregate function [sum(salary)] not allowed outside STATS command + line 1:39: aggregate function [avg(languages)] not allowed outside STATS command""", + error("from test | eval s = sum(salary), l = avg(languages) | where salary > avg(salary) and emp_no > median(emp_no)") + ); + } + public void testSpatialSort() { String prefix = "ROW wkt = [\"POINT(42.9711 -14.7553)\", \"POINT(75.8093 22.7277)\"] | MV_EXPAND wkt "; assertEquals("1:130: cannot sort on geo_point", error(prefix + "| EVAL shape = TO_GEOPOINT(wkt) | limit 5 | sort shape")); @@ -2032,10 +2043,53 @@ public void testCategorizeWithFilteredAggregations() { } public void testSortByAggregate() { - assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT count(*)")); - assertEquals("1:28: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT to_string(count(*))")); - assertEquals("1:22: Aggregate functions are not allowed in SORT [MAX]", error("ROW a = 1 | SORT 1 + max(a)")); - assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("FROM test | SORT count(*)")); + assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 | SORT count(*)")); + assertEquals( + "1:28: aggregate function [count(*)] not allowed outside STATS command", + error("ROW a = 1 | SORT to_string(count(*))") + ); + assertEquals("1:22: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | SORT 1 + max(a)")); + assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("FROM test | SORT count(*)")); + } + + public void testFilterByAggregate() { + assertEquals("1:19: aggregate function [count(*)] 
not allowed outside STATS command", error("ROW a = 1 | WHERE count(*) > 0")); + assertEquals( + "1:29: aggregate function [count(*)] not allowed outside STATS command", + error("ROW a = 1 | WHERE to_string(count(*)) IS NOT NULL") + ); + assertEquals("1:23: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | WHERE 1 + max(a) > 0")); + assertEquals( + "1:24: aggregate function [min(languages)] not allowed outside STATS command", + error("FROM employees | WHERE min(languages) > 2") + ); + } + + public void testDissectByAggregate() { + assertEquals( + "1:21: aggregate function [min(first_name)] not allowed outside STATS command", + error("from test | dissect min(first_name) \"%{foo}\"") + ); + assertEquals( + "1:21: aggregate function [avg(salary)] not allowed outside STATS command", + error("from test | dissect avg(salary) \"%{foo}\"") + ); + } + + public void testGrokByAggregate() { + assertEquals( + "1:18: aggregate function [max(last_name)] not allowed outside STATS command", + error("from test | grok max(last_name) \"%{WORD:foo}\"") + ); + assertEquals( + "1:18: aggregate function [sum(salary)] not allowed outside STATS command", + error("from test | grok sum(salary) \"%{WORD:foo}\"") + ); + } + + public void testAggregateInRow() { + assertEquals("1:13: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 + count(*)")); + assertEquals("1:9: aggregate function [avg(2)] not allowed outside STATS command", error("ROW a = avg(2)")); } public void testLookupJoinDataTypeMismatch() {