107107import org .elasticsearch .xpack .esql .plan .logical .LogicalPlan ;
108108import org .elasticsearch .xpack .esql .plan .logical .Project ;
109109import org .elasticsearch .xpack .esql .plan .logical .TopN ;
110+ import org .elasticsearch .xpack .esql .plan .logical .TopNAggregate ;
110111import org .elasticsearch .xpack .esql .plan .logical .join .Join ;
111112import org .elasticsearch .xpack .esql .plan .logical .join .JoinTypes ;
112113import org .elasticsearch .xpack .esql .plan .logical .local .LocalRelation ;
130131import org .elasticsearch .xpack .esql .plan .physical .LookupJoinExec ;
131132import org .elasticsearch .xpack .esql .plan .physical .PhysicalPlan ;
132133import org .elasticsearch .xpack .esql .plan .physical .ProjectExec ;
134+ import org .elasticsearch .xpack .esql .plan .physical .TopNAggregateExec ;
133135import org .elasticsearch .xpack .esql .plan .physical .TopNExec ;
134136import org .elasticsearch .xpack .esql .plan .physical .UnaryExec ;
135137import org .elasticsearch .xpack .esql .planner .EsPhysicalOperationProviders ;
@@ -798,6 +800,32 @@ public void testDoNotExtractGroupingFields() {
798800 assertThat (source .estimatedRowSize (), equalTo (Integer .BYTES * 2 ));
799801 }
800802
803+ public void testDoNotExtractGroupingFieldsTopN () {
804+ var plan = physicalPlan ("""
805+ from test
806+ | stats x = sum(salary) by first_name
807+ | sort first_name
808+ """ );
809+
810+ var optimized = optimizedPlan (plan );
811+ var aggregate = as (optimized , TopNAggregateExec .class );
812+ assertThat (aggregate .estimatedRowSize (), equalTo (Long .BYTES + KEYWORD_EST ));
813+ assertThat (aggregate .groupings (), hasSize (1 ));
814+
815+ var exchange = asRemoteExchange (aggregate .child ());
816+ aggregate = as (exchange .child (), TopNAggregateExec .class );
817+ assertThat (aggregate .estimatedRowSize (), equalTo (Long .BYTES + KEYWORD_EST ));
818+ assertThat (aggregate .groupings (), hasSize (1 ));
819+
820+ var extract = as (aggregate .child (), FieldExtractExec .class );
821+ assertThat (names (extract .attributesToExtract ()), equalTo (List .of ("salary" )));
822+
823+ var source = source (extract .child ());
824+ // doc id and salary are ints. salary isn't extracted.
825+ // TODO salary kind of is extracted. At least sometimes it is. should it count?
826+ assertThat (source .estimatedRowSize (), equalTo (Integer .BYTES * 2 ));
827+ }
828+
801829 public void testExtractGroupingFieldsIfAggd () {
802830 var plan = physicalPlan ("""
803831 from test
@@ -822,6 +850,30 @@ public void testExtractGroupingFieldsIfAggd() {
822850 assertThat (source .estimatedRowSize (), equalTo (Integer .BYTES + KEYWORD_EST ));
823851 }
824852
853+ public void testExtractGroupingFieldsIfAggdTopN () {
854+ var plan = physicalPlan ("""
855+ from test
856+ | stats x = count(first_name) by first_name
857+ | sort first_name
858+ """ );
859+
860+ var optimized = optimizedPlan (plan );
861+ var aggregate = as (optimized , TopNAggregateExec .class );
862+ assertThat (aggregate .groupings (), hasSize (1 ));
863+ assertThat (aggregate .estimatedRowSize (), equalTo (Long .BYTES + KEYWORD_EST ));
864+
865+ var exchange = asRemoteExchange (aggregate .child ());
866+ aggregate = as (exchange .child (), TopNAggregateExec .class );
867+ assertThat (aggregate .groupings (), hasSize (1 ));
868+ assertThat (aggregate .estimatedRowSize (), equalTo (Long .BYTES + KEYWORD_EST ));
869+
870+ var extract = as (aggregate .child (), FieldExtractExec .class );
871+ assertThat (names (extract .attributesToExtract ()), equalTo (List .of ("first_name" )));
872+
873+ var source = source (extract .child ());
874+ assertThat (source .estimatedRowSize (), equalTo (Integer .BYTES + KEYWORD_EST ));
875+ }
876+
825877 public void testExtractGroupingFieldsIfAggdWithEval () {
826878 var plan = physicalPlan ("""
827879 from test
@@ -5519,12 +5571,11 @@ public void testPushSpatialDistanceEvalWithStatsToSource() {
55195571 | SORT count DESC, country ASC
55205572 """ ;
55215573 var plan = this .physicalPlan (query , airports );
5522- var topN = as (plan , TopNExec .class );
5523- var agg = as (topN .child (), AggregateExec .class );
5524- var exchange = as (agg .child (), ExchangeExec .class );
5574+ var topNAgg = as (plan , TopNAggregateExec .class );
5575+ var exchange = as (topNAgg .child (), ExchangeExec .class );
55255576 var fragment = as (exchange .child (), FragmentExec .class );
5526- var agg2 = as (fragment .fragment (), Aggregate .class );
5527- var filter = as (agg2 .child (), Filter .class );
5577+ var topNAgg2 = as (fragment .fragment (), TopNAggregate .class );
5578+ var filter = as (topNAgg2 .child (), Filter .class );
55285579
55295580 // Validate the filter condition (two distance filters)
55305581 var and = as (filter .condition (), And .class );
@@ -5544,12 +5595,11 @@ public void testPushSpatialDistanceEvalWithStatsToSource() {
55445595
55455596 // Now optimize the plan
55465597 var optimized = optimizedPlan (plan );
5547- var topLimit = as (optimized , TopNExec .class );
5548- var aggExec = as (topLimit .child (), AggregateExec .class );
5549- var exchangeExec = as (aggExec .child (), ExchangeExec .class );
5550- var aggExec2 = as (exchangeExec .child (), AggregateExec .class );
5598+ var topNAggExec = as (optimized , TopNAggregateExec .class );
5599+ var exchangeExec = as (topNAggExec .child (), ExchangeExec .class );
5600+ var topNAggExec2 = as (exchangeExec .child (), TopNAggregateExec .class );
55515601 // TODO: Remove the eval entirely, since the distance is no longer required after filter pushdown
5552- var evalExec = as (aggExec2 .child (), EvalExec .class );
5602+ var evalExec = as (topNAggExec2 .child (), EvalExec .class );
55535603 var stDistance = as (evalExec .fields ().get (0 ).child (), StDistance .class );
55545604 assertThat ("Expect distance function to expect doc-values" , stDistance .leftDocValues (), is (true ));
55555605 var source = assertChildIsGeoPointExtract (evalExec , FieldExtractPreference .DOC_VALUES );
@@ -8067,6 +8117,47 @@ public void testSamplePushDown() {
80678117 assertThat (randomSampling .hash (), equalTo (0 ));
80688118 }
80698119
8120+ public void testTopNStats () {
8121+ var plan = physicalPlan ("""
8122+ from test
8123+ | stats x = count(first_name) by first_name, last_name
8124+ | sort x DESC, first_name NULLS LAST
8125+ | LIMIT 5
8126+ """ );
8127+
8128+ var optimized = optimizedPlan (plan );
8129+ var aggregate1 = as (optimized , TopNAggregateExec .class );
8130+
8131+ var exchange = asRemoteExchange (aggregate1 .child ());
8132+ var aggregate2 = as (exchange .child (), TopNAggregateExec .class );
8133+
8134+ var extract = as (aggregate2 .child (), FieldExtractExec .class );
8135+ assertThat (names (extract .attributesToExtract ()), equalTo (List .of ("first_name" , "last_name" )));
8136+
8137+ var source = source (extract .child ());
8138+ assertThat (source .estimatedRowSize (), equalTo (Integer .BYTES + KEYWORD_EST + KEYWORD_EST ));
8139+
8140+ assertThat (aggregate1 .groupings (), hasSize (2 ));
8141+ assertThat (aggregate1 .estimatedRowSize (), equalTo (Long .BYTES + KEYWORD_EST + KEYWORD_EST ));
8142+ assertThat (aggregate1 .order (), hasSize (2 ));
8143+ var order1 = aggregate1 .order ().get (0 );
8144+ assertThat (name (order1 .child ()), equalTo ("x" ));
8145+ assertThat (order1 .direction (), equalTo (Order .OrderDirection .DESC ));
8146+ assertThat (order1 .nullsPosition (), equalTo (Order .NullsPosition .FIRST ));
8147+ var order2 = aggregate1 .order ().get (1 );
8148+ assertThat (name (order2 .child ()), equalTo ("first_name" ));
8149+ assertThat (order2 .direction (), equalTo (Order .OrderDirection .ASC ));
8150+ assertThat (order2 .nullsPosition (), equalTo (Order .NullsPosition .LAST ));
8151+ assertThat (aggregate1 .limit ().fold (FoldContext .small ()), equalTo (5 ));
8152+
8153+ // Check that both agg nodes are identical
8154+ assertThat (aggregate1 .aggregates (), equalTo (aggregate2 .aggregates ()));
8155+ assertThat (aggregate1 .groupings (), equalTo (aggregate2 .groupings ()));
8156+ assertThat (aggregate1 .estimatedRowSize (), equalTo (aggregate2 .estimatedRowSize ()));
8157+ assertThat (aggregate1 .order (), equalTo (aggregate2 .order ()));
8158+ assertThat (aggregate1 .limit (), equalTo (aggregate2 .limit ()));
8159+ }
8160+
80708161 @ SuppressWarnings ("SameParameterValue" )
80718162 private static void assertFilterCondition (
80728163 Filter filter ,
0 commit comments