Skip to content

Commit 3ca55b1

Browse files
committed
Fixed planner tests and added some extra ones
1 parent a2d4638 commit 3ca55b1

File tree

2 files changed

+144
-10
lines changed

2 files changed

+144
-10
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.elasticsearch.xpack.esql.core.expression.Expression;
4444
import org.elasticsearch.xpack.esql.core.expression.Expressions;
4545
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
46+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
4647
import org.elasticsearch.xpack.esql.core.expression.Literal;
4748
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
4849
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
@@ -52,6 +53,7 @@
5253
import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField;
5354
import org.elasticsearch.xpack.esql.core.util.Holder;
5455
import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy;
56+
import org.elasticsearch.xpack.esql.expression.Order;
5557
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
5658
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
5759
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
@@ -89,6 +91,7 @@
8991
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
9092
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
9193
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesSourceExec;
94+
import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec;
9295
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
9396
import org.elasticsearch.xpack.esql.planner.FilterTests;
9497
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
@@ -129,6 +132,8 @@
129132
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
130133
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution;
131134
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.indexWithDateDateNanosUnionType;
135+
import static org.elasticsearch.xpack.esql.core.expression.Expressions.name;
136+
import static org.elasticsearch.xpack.esql.core.expression.Expressions.names;
132137
import static org.elasticsearch.xpack.esql.core.querydsl.query.Query.unscore;
133138
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
134139
import static org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType;
@@ -2053,6 +2058,44 @@ public void testToDateNanosPushDown() {
20532058
assertThat(expected.toString(), is(esQuery.query().toString()));
20542059
}
20552060

2061+
public void testTopNAggregate() {
2062+
var stats = EsqlTestUtils.statsForExistingField("first_name", "last_name");
2063+
var plan = plannerOptimizer.plan("""
2064+
from test
2065+
| stats x = count(first_name) by first_name, last_name
2066+
| sort x DESC, first_name NULLS LAST
2067+
| LIMIT 5
2068+
""", stats);
2069+
2070+
var aggregate1 = as(plan, TopNAggregateExec.class);
2071+
var exchange = as(aggregate1.child(), ExchangeExec.class);
2072+
var aggregate2 = as(exchange.child(), TopNAggregateExec.class);
2073+
2074+
var extract = as(aggregate2.child(), FieldExtractExec.class);
2075+
assertThat(names(extract.attributesToExtract()), equalTo(List.of("first_name", "last_name")));
2076+
var esQuery = as(extract.child(), EsQueryExec.class);
2077+
2078+
assertThat(aggregate1.groupings(), hasSize(2));
2079+
assertThat(aggregate1.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST + KEYWORD_EST));
2080+
assertThat(aggregate1.order(), hasSize(2));
2081+
var order1 = aggregate1.order().get(0);
2082+
assertThat(name(order1.child()), equalTo("x"));
2083+
assertThat(order1.direction(), equalTo(Order.OrderDirection.DESC));
2084+
assertThat(order1.nullsPosition(), equalTo(Order.NullsPosition.FIRST));
2085+
var order2 = aggregate1.order().get(1);
2086+
assertThat(name(order2.child()), equalTo("first_name"));
2087+
assertThat(order2.direction(), equalTo(Order.OrderDirection.ASC));
2088+
assertThat(order2.nullsPosition(), equalTo(Order.NullsPosition.LAST));
2089+
assertThat(aggregate1.limit().fold(FoldContext.small()), equalTo(5));
2090+
2091+
// Check that both agg nodes are identical
2092+
assertThat(aggregate1.aggregates(), equalTo(aggregate2.aggregates()));
2093+
assertThat(aggregate1.groupings(), equalTo(aggregate2.groupings()));
2094+
assertThat(aggregate1.estimatedRowSize(), equalTo(aggregate2.estimatedRowSize()));
2095+
assertThat(aggregate1.order(), equalTo(aggregate2.order()));
2096+
assertThat(aggregate1.limit(), equalTo(aggregate2.limit()));
2097+
}
2098+
20562099
private boolean isMultiTypeEsField(Expression e) {
20572100
return e instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField;
20582101
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
108108
import org.elasticsearch.xpack.esql.plan.logical.Project;
109109
import org.elasticsearch.xpack.esql.plan.logical.TopN;
110+
import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate;
110111
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
111112
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
112113
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
@@ -130,6 +131,7 @@
130131
import org.elasticsearch.xpack.esql.plan.physical.LookupJoinExec;
131132
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
132133
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
134+
import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec;
133135
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
134136
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
135137
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
@@ -798,6 +800,32 @@ public void testDoNotExtractGroupingFields() {
798800
assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 2));
799801
}
800802

803+
public void testDoNotExtractGroupingFieldsTopN() {
804+
var plan = physicalPlan("""
805+
from test
806+
| stats x = sum(salary) by first_name
807+
| sort first_name
808+
""");
809+
810+
var optimized = optimizedPlan(plan);
811+
var aggregate = as(optimized, TopNAggregateExec.class);
812+
assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST));
813+
assertThat(aggregate.groupings(), hasSize(1));
814+
815+
var exchange = asRemoteExchange(aggregate.child());
816+
aggregate = as(exchange.child(), TopNAggregateExec.class);
817+
assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST));
818+
assertThat(aggregate.groupings(), hasSize(1));
819+
820+
var extract = as(aggregate.child(), FieldExtractExec.class);
821+
assertThat(names(extract.attributesToExtract()), equalTo(List.of("salary")));
822+
823+
var source = source(extract.child());
824+
// doc id and salary are ints. salary isn't extracted.
825+
// TODO salary kind of is extracted. At least sometimes it is. should it count?
826+
assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 2));
827+
}
828+
801829
public void testExtractGroupingFieldsIfAggd() {
802830
var plan = physicalPlan("""
803831
from test
@@ -822,6 +850,30 @@ public void testExtractGroupingFieldsIfAggd() {
822850
assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES + KEYWORD_EST));
823851
}
824852

853+
public void testExtractGroupingFieldsIfAggdTopN() {
854+
var plan = physicalPlan("""
855+
from test
856+
| stats x = count(first_name) by first_name
857+
| sort first_name
858+
""");
859+
860+
var optimized = optimizedPlan(plan);
861+
var aggregate = as(optimized, TopNAggregateExec.class);
862+
assertThat(aggregate.groupings(), hasSize(1));
863+
assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST));
864+
865+
var exchange = asRemoteExchange(aggregate.child());
866+
aggregate = as(exchange.child(), TopNAggregateExec.class);
867+
assertThat(aggregate.groupings(), hasSize(1));
868+
assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST));
869+
870+
var extract = as(aggregate.child(), FieldExtractExec.class);
871+
assertThat(names(extract.attributesToExtract()), equalTo(List.of("first_name")));
872+
873+
var source = source(extract.child());
874+
assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES + KEYWORD_EST));
875+
}
876+
825877
public void testExtractGroupingFieldsIfAggdWithEval() {
826878
var plan = physicalPlan("""
827879
from test
@@ -5519,12 +5571,11 @@ public void testPushSpatialDistanceEvalWithStatsToSource() {
55195571
| SORT count DESC, country ASC
55205572
""";
55215573
var plan = this.physicalPlan(query, airports);
5522-
var topN = as(plan, TopNExec.class);
5523-
var agg = as(topN.child(), AggregateExec.class);
5524-
var exchange = as(agg.child(), ExchangeExec.class);
5574+
var topNAgg = as(plan, TopNAggregateExec.class);
5575+
var exchange = as(topNAgg.child(), ExchangeExec.class);
55255576
var fragment = as(exchange.child(), FragmentExec.class);
5526-
var agg2 = as(fragment.fragment(), Aggregate.class);
5527-
var filter = as(agg2.child(), Filter.class);
5577+
var topNAgg2 = as(fragment.fragment(), TopNAggregate.class);
5578+
var filter = as(topNAgg2.child(), Filter.class);
55285579

55295580
// Validate the filter condition (two distance filters)
55305581
var and = as(filter.condition(), And.class);
@@ -5544,12 +5595,11 @@ public void testPushSpatialDistanceEvalWithStatsToSource() {
55445595

55455596
// Now optimize the plan
55465597
var optimized = optimizedPlan(plan);
5547-
var topLimit = as(optimized, TopNExec.class);
5548-
var aggExec = as(topLimit.child(), AggregateExec.class);
5549-
var exchangeExec = as(aggExec.child(), ExchangeExec.class);
5550-
var aggExec2 = as(exchangeExec.child(), AggregateExec.class);
5598+
var topNAggExec = as(optimized, TopNAggregateExec.class);
5599+
var exchangeExec = as(topNAggExec.child(), ExchangeExec.class);
5600+
var topNAggExec2 = as(exchangeExec.child(), TopNAggregateExec.class);
55515601
// TODO: Remove the eval entirely, since the distance is no longer required after filter pushdown
5552-
var evalExec = as(aggExec2.child(), EvalExec.class);
5602+
var evalExec = as(topNAggExec2.child(), EvalExec.class);
55535603
var stDistance = as(evalExec.fields().get(0).child(), StDistance.class);
55545604
assertThat("Expect distance function to expect doc-values", stDistance.leftDocValues(), is(true));
55555605
var source = assertChildIsGeoPointExtract(evalExec, FieldExtractPreference.DOC_VALUES);
@@ -8067,6 +8117,47 @@ public void testSamplePushDown() {
80678117
assertThat(randomSampling.hash(), equalTo(0));
80688118
}
80698119

8120+
public void testTopNStats() {
8121+
var plan = physicalPlan("""
8122+
from test
8123+
| stats x = count(first_name) by first_name, last_name
8124+
| sort x DESC, first_name NULLS LAST
8125+
| LIMIT 5
8126+
""");
8127+
8128+
var optimized = optimizedPlan(plan);
8129+
var aggregate1 = as(optimized, TopNAggregateExec.class);
8130+
8131+
var exchange = asRemoteExchange(aggregate1.child());
8132+
var aggregate2 = as(exchange.child(), TopNAggregateExec.class);
8133+
8134+
var extract = as(aggregate2.child(), FieldExtractExec.class);
8135+
assertThat(names(extract.attributesToExtract()), equalTo(List.of("first_name", "last_name")));
8136+
8137+
var source = source(extract.child());
8138+
assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES + KEYWORD_EST + KEYWORD_EST));
8139+
8140+
assertThat(aggregate1.groupings(), hasSize(2));
8141+
assertThat(aggregate1.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST + KEYWORD_EST));
8142+
assertThat(aggregate1.order(), hasSize(2));
8143+
var order1 = aggregate1.order().get(0);
8144+
assertThat(name(order1.child()), equalTo("x"));
8145+
assertThat(order1.direction(), equalTo(Order.OrderDirection.DESC));
8146+
assertThat(order1.nullsPosition(), equalTo(Order.NullsPosition.FIRST));
8147+
var order2 = aggregate1.order().get(1);
8148+
assertThat(name(order2.child()), equalTo("first_name"));
8149+
assertThat(order2.direction(), equalTo(Order.OrderDirection.ASC));
8150+
assertThat(order2.nullsPosition(), equalTo(Order.NullsPosition.LAST));
8151+
assertThat(aggregate1.limit().fold(FoldContext.small()), equalTo(5));
8152+
8153+
// Check that both agg nodes are identical
8154+
assertThat(aggregate1.aggregates(), equalTo(aggregate2.aggregates()));
8155+
assertThat(aggregate1.groupings(), equalTo(aggregate2.groupings()));
8156+
assertThat(aggregate1.estimatedRowSize(), equalTo(aggregate2.estimatedRowSize()));
8157+
assertThat(aggregate1.order(), equalTo(aggregate2.order()));
8158+
assertThat(aggregate1.limit(), equalTo(aggregate2.limit()));
8159+
}
8160+
80708161
@SuppressWarnings("SameParameterValue")
80718162
private static void assertFilterCondition(
80728163
Filter filter,

0 commit comments

Comments
 (0)