diff --git a/docs/changelog/124446.yaml b/docs/changelog/124446.yaml new file mode 100644 index 0000000000000..c330dbf682780 --- /dev/null +++ b/docs/changelog/124446.yaml @@ -0,0 +1,6 @@ +pr: 124446 +summary: "ESQL: Fail in `AggregateFunction` when `LogicalPlan` is not an `Aggregate`" +area: ES|QL +type: bug +issues: + - 124311 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index aab893e6ed5cc..62554570ba3c0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -19,8 +19,9 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.Dedup; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import java.io.IOException; import java.util.List; @@ -139,14 +140,12 @@ public boolean equals(Object obj) { @Override public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() { return (p, failures) -> { - if (p instanceof OrderBy order) { - order.order().forEach(o -> { - o.forEachDown(Function.class, f -> { - if (f instanceof AggregateFunction) { - failures.add(fail(f, "Aggregate functions are not allowed in SORT [{}]", f.functionName())); - } - }); - }); + // `dedup` for now is not exposed as a command, + // so allowing aggregate functions for dedup explicitly is just an internal implementation detail + if ((p instanceof Aggregate) == false && (p instanceof Dedup) == false) { + 
p.expressions().forEach(x -> x.forEachDown(AggregateFunction.class, af -> { + failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText())); + })); } }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java index f8b4f9e549916..e85c63e79a7e9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java @@ -24,7 +24,6 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.GeneratingPlan; @@ -178,10 +177,6 @@ public void postAnalysisVerification(Failures failures) { ) ); } - // check no aggregate functions are used - field.forEachDown(AggregateFunction.class, af -> { - failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText())); - }); }); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 5d7d881141e7a..58c12afb92241 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1014,6 +1014,17 @@ public void testNotFoundFieldInNestedFunction() { line 1:23: Unknown column [avg]""", error("from test | stats c = avg by missing + 1, not_found")); } + public void testMultipleAggsOutsideStats() { + assertEquals( + """ + 1:71: aggregate function 
[avg(salary)] not allowed outside STATS command + line 1:96: aggregate function [median(emp_no)] not allowed outside STATS command + line 1:22: aggregate function [sum(salary)] not allowed outside STATS command + line 1:39: aggregate function [avg(languages)] not allowed outside STATS command""", + error("from test | eval s = sum(salary), l = avg(languages) | where salary > avg(salary) and emp_no > median(emp_no)") + ); + } + public void testSpatialSort() { String prefix = "ROW wkt = [\"POINT(42.9711 -14.7553)\", \"POINT(75.8093 22.7277)\"] | MV_EXPAND wkt "; assertEquals("1:130: cannot sort on geo_point", error(prefix + "| EVAL shape = TO_GEOPOINT(wkt) | limit 5 | sort shape")); @@ -2107,10 +2118,53 @@ public void testChangePoint_valueNumeric() { } public void testSortByAggregate() { - assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT count(*)")); - assertEquals("1:28: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT to_string(count(*))")); - assertEquals("1:22: Aggregate functions are not allowed in SORT [MAX]", error("ROW a = 1 | SORT 1 + max(a)")); - assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("FROM test | SORT count(*)")); + assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 | SORT count(*)")); + assertEquals( + "1:28: aggregate function [count(*)] not allowed outside STATS command", + error("ROW a = 1 | SORT to_string(count(*))") + ); + assertEquals("1:22: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | SORT 1 + max(a)")); + assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("FROM test | SORT count(*)")); + } + + public void testFilterByAggregate() { + assertEquals("1:19: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 | WHERE count(*) > 0")); + assertEquals( + "1:29: aggregate function 
[count(*)] not allowed outside STATS command", + error("ROW a = 1 | WHERE to_string(count(*)) IS NOT NULL") + ); + assertEquals("1:23: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | WHERE 1 + max(a) > 0")); + assertEquals( + "1:24: aggregate function [min(languages)] not allowed outside STATS command", + error("FROM employees | WHERE min(languages) > 2") + ); + } + + public void testDissectByAggregate() { + assertEquals( + "1:21: aggregate function [min(first_name)] not allowed outside STATS command", + error("from test | dissect min(first_name) \"%{foo}\"") + ); + assertEquals( + "1:21: aggregate function [avg(salary)] not allowed outside STATS command", + error("from test | dissect avg(salary) \"%{foo}\"") + ); + } + + public void testGrokByAggregate() { + assertEquals( + "1:18: aggregate function [max(last_name)] not allowed outside STATS command", + error("from test | grok max(last_name) \"%{WORD:foo}\"") + ); + assertEquals( + "1:18: aggregate function [sum(salary)] not allowed outside STATS command", + error("from test | grok sum(salary) \"%{WORD:foo}\"") + ); + } + + public void testAggregateInRow() { + assertEquals("1:13: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 + count(*)")); + assertEquals("1:9: aggregate function [avg(2)] not allowed outside STATS command", error("ROW a = avg(2)")); } public void testLookupJoinDataTypeMismatch() {