@@ -28,7 +28,6 @@ use ballista_core::{
2828use datafusion:: physical_plan:: coalesce_partitions:: CoalescePartitionsExec ;
2929use datafusion:: physical_plan:: repartition:: RepartitionExec ;
3030use datafusion:: physical_plan:: sorts:: sort_preserving_merge:: SortPreservingMergeExec ;
31- use datafusion:: physical_plan:: windows:: WindowAggExec ;
3231use datafusion:: physical_plan:: {
3332 with_new_children_if_necessary, ExecutionPlan , Partitioning ,
3433} ;
@@ -148,12 +147,6 @@ impl DistributedPlanner {
148147 Ok ( ( children[ 0 ] . clone ( ) , stages) )
149148 }
150149 }
151- } else if let Some ( window) =
152- execution_plan. as_any ( ) . downcast_ref :: < WindowAggExec > ( )
153- {
154- Err ( BallistaError :: NotImplemented ( format ! (
155- "WindowAggExec with window {window:?}"
156- ) ) )
157150 } else {
158151 Ok ( (
159152 with_new_children_if_necessary ( execution_plan, children) ?,
@@ -305,15 +298,20 @@ mod test {
305298 use crate :: planner:: DistributedPlanner ;
306299 use crate :: test_utils:: datafusion_test_context;
307300 use ballista_core:: error:: BallistaError ;
308- use ballista_core:: execution_plans:: UnresolvedShuffleExec ;
301+ use ballista_core:: execution_plans:: { ShuffleWriterExec , UnresolvedShuffleExec } ;
309302 use ballista_core:: serde:: BallistaCodec ;
303+ use datafusion:: arrow:: compute:: SortOptions ;
304+ use datafusion:: physical_expr:: expressions:: Column ;
310305 use datafusion:: physical_plan:: aggregates:: { AggregateExec , AggregateMode } ;
311306 use datafusion:: physical_plan:: coalesce_batches:: CoalesceBatchesExec ;
307+ use datafusion:: physical_plan:: filter:: FilterExec ;
312308 use datafusion:: physical_plan:: joins:: HashJoinExec ;
313309 use datafusion:: physical_plan:: projection:: ProjectionExec ;
314310 use datafusion:: physical_plan:: sorts:: sort:: SortExec ;
315311 use datafusion:: physical_plan:: sorts:: sort_preserving_merge:: SortPreservingMergeExec ;
312+ use datafusion:: physical_plan:: windows:: BoundedWindowAggExec ;
316313 use datafusion:: physical_plan:: { displayable, ExecutionPlan } ;
314+ use datafusion:: physical_plan:: { InputOrderMode , Partitioning } ;
317315 use datafusion:: prelude:: SessionContext ;
318316 use datafusion_proto:: physical_plan:: AsExecutionPlan ;
319317 use datafusion_proto:: protobuf:: LogicalPlanNode ;
@@ -592,8 +590,121 @@ order by
592590 Ok ( ( ) )
593591 }
594592
595- #[ ignore]
596- // enable when upgrading Datafusion, a bug is fixed with https://github.com/apache/datafusion/pull/11926/
593+ #[ tokio:: test]
594+ async fn distributed_window_plan ( ) -> Result < ( ) , BallistaError > {
595+ let ctx = datafusion_test_context ( "testdata" ) . await ?;
596+ let session_state = ctx. state ( ) ;
597+
598+ // simplified form of TPC-DS query 67
599+ let df = ctx
600+ . sql (
601+ "
602+ select * from (
603+ select
604+ l_shipmode,
605+ l_shipdate,
606+ rank() over (partition by l_shipmode order by l_shipdate desc) rk
607+ from lineitem
608+ ) alias1
609+ where rk <= 100 order by l_shipdate, rk;
610+ " ,
611+ )
612+ . await ?;
613+
614+ let plan = df. into_optimized_plan ( ) ?;
615+ let plan = session_state. optimize ( & plan) ?;
616+ let plan = session_state. create_physical_plan ( & plan) . await ?;
617+
618+ let mut planner = DistributedPlanner :: new ( ) ;
619+ let job_uuid = Uuid :: new_v4 ( ) ;
620+ let stages = planner. plan_query_stages ( & job_uuid. to_string ( ) , plan) ?;
621+ for ( i, stage) in stages. iter ( ) . enumerate ( ) {
622+ println ! ( "Stage {i}:\n {}" , displayable( stage. as_ref( ) ) . indent( false ) ) ;
623+ }
624+ /*
625+ expected result:
626+ Stage 0:
627+ ShuffleWriterExec: Some(Hash([Column { name: "l_shipmode", index: 1 }], 2))
628+ CsvExec: file_groups={2 groups: [[testdata/lineitem/partition0.tbl], [testdata/lineitem/partition1.tbl]]}, projection=[l_shipdate, l_shipmode], has_header=false
629+
630+ Stage 1:
631+ ShuffleWriterExec: None
632+ SortExec: expr=[l_shipdate@1 ASC NULLS LAST,rk@2 ASC NULLS LAST], preserve_partitioning=[true]
633+ ProjectionExec: expr=[l_shipmode@1 as l_shipmode, l_shipdate@0 as l_shipdate, RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rk]
634+ CoalesceBatchesExec: target_batch_size=8192
635+ FilterExec: RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 <= 100
636+ BoundedWindowAggExec: wdw=[RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(IntervalMonthDayNano("NULL")), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]
637+ SortExec: expr=[l_shipmode@1 ASC NULLS LAST,l_shipdate@0 DESC], preserve_partitioning=[true]
638+ CoalesceBatchesExec: target_batch_size=8192
639+ UnresolvedShuffleExec
640+
641+ Stage 2:
642+ ShuffleWriterExec: None
643+ SortPreservingMergeExec: [l_shipdate@1 ASC NULLS LAST,rk@2 ASC NULLS LAST]
644+ UnresolvedShuffleExec
645+
646+ */
647+
648+ assert_eq ! ( 3 , stages. len( ) ) ;
649+
650+ // stage0
651+ let stage0 = stages[ 0 ] . clone ( ) ;
652+ let shuffle_write = downcast_exec ! ( stage0, ShuffleWriterExec ) ;
653+ let partitioning = shuffle_write. shuffle_output_partitioning ( ) . expect ( "stage0" ) ;
654+ assert_eq ! ( 2 , partitioning. partition_count( ) ) ;
655+ let partition_col = match partitioning {
656+ Partitioning :: Hash ( exprs, 2 ) => match exprs. as_slice ( ) {
657+ [ ref col] => col. as_any ( ) . downcast_ref :: < Column > ( ) ,
658+ _ => None ,
659+ } ,
660+ _ => None ,
661+ } ;
662+ assert_eq ! ( Some ( & Column :: new( "l_shipmode" , 1 ) ) , partition_col) ;
663+
664+ // stage1
665+ let sort = downcast_exec ! ( stages[ 1 ] . children( ) [ 0 ] , SortExec ) ;
666+ let projection = downcast_exec ! ( sort. children( ) [ 0 ] , ProjectionExec ) ;
667+ let coalesce = downcast_exec ! ( projection. children( ) [ 0 ] , CoalesceBatchesExec ) ;
668+ let filter = downcast_exec ! ( coalesce. children( ) [ 0 ] , FilterExec ) ;
669+ let window = downcast_exec ! ( filter. children( ) [ 0 ] , BoundedWindowAggExec ) ;
670+ let partition_by = match window. partition_keys . as_slice ( ) {
671+ [ ref col] => col. as_any ( ) . downcast_ref :: < Column > ( ) ,
672+ _ => None ,
673+ } ;
674+ assert_eq ! ( Some ( & Column :: new( "l_shipmode" , 1 ) ) , partition_by) ;
675+ assert_eq ! ( InputOrderMode :: Sorted , window. input_order_mode) ;
676+ let sort = downcast_exec ! ( window. children( ) [ 0 ] , SortExec ) ;
677+ match sort. expr ( ) {
678+ [ expr1, expr2] => {
679+ assert_eq ! (
680+ SortOptions {
681+ descending: false ,
682+ nulls_first: false
683+ } ,
684+ expr1. options
685+ ) ;
686+ assert_eq ! (
687+ Some ( & Column :: new( "l_shipmode" , 1 ) ) ,
688+ expr1. expr. as_any( ) . downcast_ref( )
689+ ) ;
690+ assert_eq ! (
691+ SortOptions {
692+ descending: true ,
693+ nulls_first: true
694+ } ,
695+ expr2. options
696+ ) ;
697+ assert_eq ! (
698+ Some ( & Column :: new( "l_shipdate" , 0 ) ) ,
699+ expr2. expr. as_any( ) . downcast_ref( )
700+ ) ;
701+ }
702+ _ => panic ! ( "invalid sort {:?}" , sort) ,
703+ } ;
704+
705+ Ok ( ( ) )
706+ }
707+
597708 #[ tokio:: test]
598709 async fn roundtrip_serde_aggregate ( ) -> Result < ( ) , BallistaError > {
599710 let ctx = datafusion_test_context ( "testdata" ) . await ?;
0 commit comments