Skip to content

Commit 5b6b50b

Browse files
authored
support window functions (#1112)
* support window functions * clippy * new test for distributed planner for a window expression query, enable previously disabled test
1 parent 68b2277 commit 5b6b50b

File tree

1 file changed

+121
-10
lines changed

1 file changed

+121
-10
lines changed

ballista/scheduler/src/planner.rs

Lines changed: 121 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use ballista_core::{
2828
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
2929
use datafusion::physical_plan::repartition::RepartitionExec;
3030
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
31-
use datafusion::physical_plan::windows::WindowAggExec;
3231
use datafusion::physical_plan::{
3332
with_new_children_if_necessary, ExecutionPlan, Partitioning,
3433
};
@@ -148,12 +147,6 @@ impl DistributedPlanner {
148147
Ok((children[0].clone(), stages))
149148
}
150149
}
151-
} else if let Some(window) =
152-
execution_plan.as_any().downcast_ref::<WindowAggExec>()
153-
{
154-
Err(BallistaError::NotImplemented(format!(
155-
"WindowAggExec with window {window:?}"
156-
)))
157150
} else {
158151
Ok((
159152
with_new_children_if_necessary(execution_plan, children)?,
@@ -305,15 +298,20 @@ mod test {
305298
use crate::planner::DistributedPlanner;
306299
use crate::test_utils::datafusion_test_context;
307300
use ballista_core::error::BallistaError;
308-
use ballista_core::execution_plans::UnresolvedShuffleExec;
301+
use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec};
309302
use ballista_core::serde::BallistaCodec;
303+
use datafusion::arrow::compute::SortOptions;
304+
use datafusion::physical_expr::expressions::Column;
310305
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
311306
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
307+
use datafusion::physical_plan::filter::FilterExec;
312308
use datafusion::physical_plan::joins::HashJoinExec;
313309
use datafusion::physical_plan::projection::ProjectionExec;
314310
use datafusion::physical_plan::sorts::sort::SortExec;
315311
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
312+
use datafusion::physical_plan::windows::BoundedWindowAggExec;
316313
use datafusion::physical_plan::{displayable, ExecutionPlan};
314+
use datafusion::physical_plan::{InputOrderMode, Partitioning};
317315
use datafusion::prelude::SessionContext;
318316
use datafusion_proto::physical_plan::AsExecutionPlan;
319317
use datafusion_proto::protobuf::LogicalPlanNode;
@@ -592,8 +590,121 @@ order by
592590
Ok(())
593591
}
594592

595-
#[ignore]
596-
// enable when upgrading Datafusion, a bug is fixed with https://github.com/apache/datafusion/pull/11926/
593+
#[tokio::test]
594+
async fn distributed_window_plan() -> Result<(), BallistaError> {
595+
let ctx = datafusion_test_context("testdata").await?;
596+
let session_state = ctx.state();
597+
598+
// simplified form of TPC-DS query 67
599+
let df = ctx
600+
.sql(
601+
"
602+
select * from (
603+
select
604+
l_shipmode,
605+
l_shipdate,
606+
rank() over (partition by l_shipmode order by l_shipdate desc) rk
607+
from lineitem
608+
) alias1
609+
where rk <= 100 order by l_shipdate, rk;
610+
",
611+
)
612+
.await?;
613+
614+
let plan = df.into_optimized_plan()?;
615+
let plan = session_state.optimize(&plan)?;
616+
let plan = session_state.create_physical_plan(&plan).await?;
617+
618+
let mut planner = DistributedPlanner::new();
619+
let job_uuid = Uuid::new_v4();
620+
let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
621+
for (i, stage) in stages.iter().enumerate() {
622+
println!("Stage {i}:\n{}", displayable(stage.as_ref()).indent(false));
623+
}
624+
/*
625+
expected result:
626+
Stage 0:
627+
ShuffleWriterExec: Some(Hash([Column { name: "l_shipmode", index: 1 }], 2))
628+
CsvExec: file_groups={2 groups: [[testdata/lineitem/partition0.tbl], [testdata/lineitem/partition1.tbl]]}, projection=[l_shipdate, l_shipmode], has_header=false
629+
630+
Stage 1:
631+
ShuffleWriterExec: None
632+
SortExec: expr=[l_shipdate@1 ASC NULLS LAST,rk@2 ASC NULLS LAST], preserve_partitioning=[true]
633+
ProjectionExec: expr=[l_shipmode@1 as l_shipmode, l_shipdate@0 as l_shipdate, RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rk]
634+
CoalesceBatchesExec: target_batch_size=8192
635+
FilterExec: RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 <= 100
636+
BoundedWindowAggExec: wdw=[RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "RANK() PARTITION BY [lineitem.l_shipmode] ORDER BY [lineitem.l_shipdate DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(IntervalMonthDayNano("NULL")), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]
637+
SortExec: expr=[l_shipmode@1 ASC NULLS LAST,l_shipdate@0 DESC], preserve_partitioning=[true]
638+
CoalesceBatchesExec: target_batch_size=8192
639+
UnresolvedShuffleExec
640+
641+
Stage 2:
642+
ShuffleWriterExec: None
643+
SortPreservingMergeExec: [l_shipdate@1 ASC NULLS LAST,rk@2 ASC NULLS LAST]
644+
UnresolvedShuffleExec
645+
646+
*/
647+
648+
assert_eq!(3, stages.len());
649+
650+
// stage0
651+
let stage0 = stages[0].clone();
652+
let shuffle_write = downcast_exec!(stage0, ShuffleWriterExec);
653+
let partitioning = shuffle_write.shuffle_output_partitioning().expect("stage0");
654+
assert_eq!(2, partitioning.partition_count());
655+
let partition_col = match partitioning {
656+
Partitioning::Hash(exprs, 2) => match exprs.as_slice() {
657+
[ref col] => col.as_any().downcast_ref::<Column>(),
658+
_ => None,
659+
},
660+
_ => None,
661+
};
662+
assert_eq!(Some(&Column::new("l_shipmode", 1)), partition_col);
663+
664+
// stage1
665+
let sort = downcast_exec!(stages[1].children()[0], SortExec);
666+
let projection = downcast_exec!(sort.children()[0], ProjectionExec);
667+
let coalesce = downcast_exec!(projection.children()[0], CoalesceBatchesExec);
668+
let filter = downcast_exec!(coalesce.children()[0], FilterExec);
669+
let window = downcast_exec!(filter.children()[0], BoundedWindowAggExec);
670+
let partition_by = match window.partition_keys.as_slice() {
671+
[ref col] => col.as_any().downcast_ref::<Column>(),
672+
_ => None,
673+
};
674+
assert_eq!(Some(&Column::new("l_shipmode", 1)), partition_by);
675+
assert_eq!(InputOrderMode::Sorted, window.input_order_mode);
676+
let sort = downcast_exec!(window.children()[0], SortExec);
677+
match sort.expr() {
678+
[expr1, expr2] => {
679+
assert_eq!(
680+
SortOptions {
681+
descending: false,
682+
nulls_first: false
683+
},
684+
expr1.options
685+
);
686+
assert_eq!(
687+
Some(&Column::new("l_shipmode", 1)),
688+
expr1.expr.as_any().downcast_ref()
689+
);
690+
assert_eq!(
691+
SortOptions {
692+
descending: true,
693+
nulls_first: true
694+
},
695+
expr2.options
696+
);
697+
assert_eq!(
698+
Some(&Column::new("l_shipdate", 0)),
699+
expr2.expr.as_any().downcast_ref()
700+
);
701+
}
702+
_ => panic!("invalid sort {:?}", sort),
703+
};
704+
705+
Ok(())
706+
}
707+
597708
#[tokio::test]
598709
async fn roundtrip_serde_aggregate() -> Result<(), BallistaError> {
599710
let ctx = datafusion_test_context("testdata").await?;

0 commit comments

Comments
 (0)