@@ -11,7 +11,7 @@ use datafusion::error::DataFusionError;
1111
1212use datafusion:: execution:: TaskContext ;
1313use datafusion:: logical_expr:: Accumulator ;
14- use datafusion:: physical_expr:: EquivalenceProperties ;
14+ use datafusion:: physical_expr:: { EquivalenceProperties , LexRequirement , PhysicalSortRequirement } ;
1515use datafusion:: physical_plan:: aggregates:: { create_accumulators, AccumulatorItem , AggregateMode } ;
1616use datafusion:: physical_plan:: common:: collect;
1717use datafusion:: physical_plan:: filter:: FilterExec ;
@@ -56,7 +56,8 @@ pub struct AggregateTopKExec {
5656 /// Always an instance of ClusterSendExec or WorkerExec.
5757 pub cluster : Arc < dyn ExecutionPlan > ,
5858 pub schema : SchemaRef ,
59- cache : PlanProperties ,
59+ pub cache : PlanProperties ,
60+ pub sort_requirement : LexRequirement ,
6061}
6162
6263/// Third item is the neutral value for the corresponding aggregate function.
@@ -72,6 +73,8 @@ impl AggregateTopKExec {
7273 having : Option < Arc < dyn PhysicalExpr > > ,
7374 cluster : Arc < dyn ExecutionPlan > ,
7475 schema : SchemaRef ,
76+ // sort_requirement is passed in by topk_plan mostly for the sake of code deduplication
77+ sort_requirement : LexRequirement ,
7578 ) -> AggregateTopKExec {
7679 assert_eq ! ( schema. fields( ) . len( ) , agg_expr. len( ) + key_len) ;
7780 assert_eq ! ( agg_fun. len( ) , agg_expr. len( ) ) ;
@@ -95,6 +98,7 @@ impl AggregateTopKExec {
9598 cluster,
9699 schema,
97100 cache,
101+ sort_requirement,
98102 }
99103 }
100104
@@ -171,13 +175,20 @@ impl ExecutionPlan for AggregateTopKExec {
171175 cluster,
172176 schema : self . schema . clone ( ) ,
173177 cache : self . cache . clone ( ) ,
178+ sort_requirement : self . sort_requirement . clone ( ) ,
174179 } ) )
175180 }
176181
177182 fn properties ( & self ) -> & PlanProperties {
178183 & self . cache
179184 }
180185
186+ // TODO upgrade DF: Probably should include output ordering in the PlanProperties.
187+
188+ fn required_input_ordering ( & self ) -> Vec < Option < LexRequirement > > {
189+ vec ! [ Some ( self . sort_requirement. clone( ) ) ]
190+ }
191+
181192 #[ tracing:: instrument( level = "trace" , skip( self ) ) ]
182193 fn execute (
183194 & self ,
@@ -996,6 +1007,7 @@ fn finalize_aggregation_into(
9961007#[ cfg( test) ]
9971008mod tests {
9981009 use super :: * ;
1010+ use crate :: queryplanner:: topk:: plan:: make_sort_expr;
9991011 use crate :: queryplanner:: topk:: { AggregateTopKExec , SortColumn } ;
10001012 use datafusion:: arrow:: array:: { Array , ArrayRef , Int64Array } ;
10011013 use datafusion:: arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
@@ -1418,17 +1430,21 @@ mod tests {
14181430 // config: ExecutionConfig::new(),
14191431 // execution_props: ExecutionProps::new(),
14201432 // };
1421- let agg_exprs = aggs
1433+ let agg_functions = aggs
14221434 . iter ( )
14231435 . enumerate ( )
1424- . map ( |( i, f) | Expr :: AggregateFunction ( AggregateFunction {
1436+ . map ( |( i, f) | AggregateFunction {
14251437 func : topk_fun_to_fusion_type ( & ctx, f) . unwrap ( ) ,
14261438 args : vec ! [ Expr :: Column ( Column :: from_name( format!( "agg{}" , i + 1 ) ) ) ] ,
14271439 distinct : false ,
14281440 filter : None ,
14291441 order_by : None ,
14301442 null_treatment : None ,
1431- } ) ) ;
1443+ } )
1444+ . collect :: < Vec < _ > > ( ) ;
1445+ let agg_exprs = agg_functions. iter ( ) . map ( |agg_fn|
1446+ Expr :: AggregateFunction ( agg_fn. clone ( ) )
1447+ ) ;
14321448 let physical_agg_exprs: Vec < ( AggregateFunctionExpr , Option < Arc < dyn PhysicalExpr > > , Option < Vec < datafusion:: physical_expr:: PhysicalSortExpr > > ) > = agg_exprs
14331449 . map ( |e| {
14341450 Ok ( create_aggregate_expr_and_maybe_filter (
@@ -1439,7 +1455,7 @@ mod tests {
14391455 ) ?)
14401456 } )
14411457 . collect :: < Result < Vec < _ > , DataFusionError > > ( ) ?;
1442- let ( agg_fn_exprs, agg_phys_exprs , _order_by) : ( Vec < _ > , Vec < _ > , Vec < _ > ) = itertools:: multiunzip ( physical_agg_exprs) ;
1458+ let ( agg_fn_exprs, _agg_phys_exprs , _order_by) : ( Vec < _ > , Vec < _ > , Vec < _ > ) = itertools:: multiunzip ( physical_agg_exprs) ;
14431459
14441460 let output_agg_fields = agg_fn_exprs
14451461 . iter ( )
@@ -1453,6 +1469,23 @@ mod tests {
14531469 . collect :: < Vec < _ > > ( ) ,
14541470 ) ) ;
14551471
1472+ let sort_requirement = order_by. iter ( ) . map ( |c| {
1473+ let i = key_len + c. agg_index ;
1474+ PhysicalSortRequirement {
1475+ expr : make_sort_expr (
1476+ & input_schema. inner ( ) ,
1477+ & aggs[ c. agg_index ] ,
1478+ Arc :: new ( datafusion:: physical_expr:: expressions:: Column :: new ( input_schema. field ( i) . name ( ) , i) ) ,
1479+ & agg_functions[ c. agg_index ] . args ,
1480+ & input_schema,
1481+ ) ,
1482+ options : Some ( SortOptions {
1483+ descending : !c. asc ,
1484+ nulls_first : c. nulls_first ,
1485+ } ) ,
1486+ }
1487+ } ) . collect ( ) ;
1488+
14561489 Ok ( AggregateTopKExec :: new (
14571490 limit,
14581491 key_len,
@@ -1462,6 +1495,7 @@ mod tests {
14621495 None ,
14631496 Arc :: new ( EmptyExec :: new ( input_schema. inner ( ) . clone ( ) ) ) ,
14641497 output_schema,
1498+ sort_requirement,
14651499 ) )
14661500 }
14671501
0 commit comments