
Commit b04c86a

Add stage planner tests (#78)
* Uncomment tests in favor of just #[ignore]-ing them (see the sketch after this description)
* Remove unused import statements
* Move common tpch module to common
* Unignore one more test
* Fix ArrowFlightReadExec
* Add stage planner tests
* Add test_left_join_distributed
1 parent 49af3ae commit b04c86a
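
The first bullet refers to Rust's built-in test-skipping attribute. As an illustrative sketch only (the test name below is hypothetical, not from this commit): an #[ignore]-ed test stays compiled, so it cannot silently rot the way commented-out code does, yet a plain `cargo test` run skips it.

```rust
// Hypothetical illustration of the #[ignore] pattern from the first bullet.
// The test still compiles, but `cargo test` skips it by default; run it
// explicitly with `cargo test -- --ignored`.
#[tokio::test]
#[ignore]
async fn test_not_yet_supported() {
    // test body compiles but does not run in a normal test pass
}
```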

File tree

4 files changed: +324, -2 lines changed


src/physical_optimizer.rs

Lines changed: 277 additions & 0 deletions
@@ -258,3 +258,280 @@ impl TreeNodeRewriter for StagePlanner {

```rust
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::assert_snapshot;
    use crate::physical_optimizer::DistributedPhysicalOptimizerRule;
    use crate::test_utils::register_parquet_tables;
    use datafusion::error::DataFusionError;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_plan::displayable;
    use datafusion::prelude::{SessionConfig, SessionContext};
    use std::sync::Arc;

    /* schema for the "weather" table

    MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
    MaxTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
    Rainfall [type=DOUBLE] [repetitiontype=OPTIONAL]
    Evaporation [type=DOUBLE] [repetitiontype=OPTIONAL]
    Sunshine [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindGustDir [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindGustSpeed [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindDir9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindDir3pm [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindSpeed9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindSpeed3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Humidity9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Humidity3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Pressure9am [type=DOUBLE] [repetitiontype=OPTIONAL]
    Pressure3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
    Cloud9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Cloud3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Temp9am [type=DOUBLE] [repetitiontype=OPTIONAL]
    Temp3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
    RainToday [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    RISK_MM [type=DOUBLE] [repetitiontype=OPTIONAL]
    RainTomorrow [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    */

    #[tokio::test]
    async fn test_select_all() {
        let query = r#"SELECT * FROM weather"#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 1 Task: partitions: 0,unassigned]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_aggregation() {
        let query =
            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 3 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 2

        └──────────────────────────────────────────────────
        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4
        │partitions [out:4 ] ArrowFlightReadExec: Stage 1

        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_aggregation_with_partitions_per_task() {
        let query =
            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
        let plan = sql_to_explain_partitions_per_task(query, 2).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 3 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 2

        └──────────────────────────────────────────────────
        ┌───── Stage 2 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=2
        │partitions [out:2 <-- in:4 ] PartitionIsolatorExec [providing upto 2 partitions]
        │partitions [out:4 ] ArrowFlightReadExec: Stage 1

        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2
        │partitions [out:2 <-- in:1 ] PartitionIsolatorExec [providing upto 2 partitions]
        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_left_join() {
        let query = r#"SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday" "#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 1 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:1 <-- in:1 ] HashJoinExec: mode=Partitioned, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_left_join_distributed() {
        let query = r#"
        WITH a AS (
            SELECT
                AVG("MinTemp") as "MinTemp",
                "RainTomorrow"
            FROM weather
            WHERE "RainToday" = 'yes'
            GROUP BY "RainTomorrow"
        ), b AS (
            SELECT
                AVG("MaxTemp") as "MaxTemp",
                "RainTomorrow"
            FROM weather
            WHERE "RainToday" = 'no'
            GROUP BY "RainTomorrow"
        )
        SELECT
            a."MinTemp",
            b."MaxTemp"
        FROM a
        LEFT JOIN b
        ON a."RainTomorrow" = b."RainTomorrow"

        "#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 5 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 <-- in:1 ] HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainTomorrow@1, RainTomorrow@1)], projection=[MinTemp@0, MaxTemp@2]
        │partitions [out:1 <-- in:4 ] CoalescePartitionsExec
        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[avg(weather.MinTemp)@1 as MinTemp, RainTomorrow@0 as RainTomorrow]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MinTemp)]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[avg(weather.MaxTemp)@1 as MaxTemp, RainTomorrow@0 as RainTomorrow]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MaxTemp)]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 4

        └──────────────────────────────────────────────────
        ┌───── Stage 4 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MaxTemp)]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = no, projection=[MaxTemp@0, RainTomorrow@2]
        │partitions [out:4 ] ArrowFlightReadExec: Stage 3

        └──────────────────────────────────────────────────
        ┌───── Stage 3 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = no, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= no AND no <= RainToday_max@1, required_guarantees=[RainToday in (no)]


        └──────────────────────────────────────────────────
        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MinTemp)]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = yes, projection=[MinTemp@0, RainTomorrow@2]
        │partitions [out:4 ] ArrowFlightReadExec: Stage 1

        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = yes, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= yes AND yes <= RainToday_max@1, required_guarantees=[RainToday in (yes)]


        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_sort() {
        let query = r#"SELECT * FROM weather ORDER BY "MinTemp" DESC "#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 1 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] SortExec: expr=[MinTemp@0 DESC], preserve_partitioning=[false]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_distinct() {
        let query = r#"SELECT DISTINCT "RainToday", "WindGustDir" FROM weather"#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 3 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 2

        └──────────────────────────────────────────────────
        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0, WindGustDir@1], 4), input_partitions=4
        │partitions [out:4 ] ArrowFlightReadExec: Stage 1

        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday, WindGustDir], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    async fn sql_to_explain(query: &str) -> Result<String, DataFusionError> {
        sql_to_explain_with_rule(query, DistributedPhysicalOptimizerRule::new()).await
    }

    async fn sql_to_explain_partitions_per_task(
        query: &str,
        partitions_per_task: usize,
    ) -> Result<String, DataFusionError> {
        sql_to_explain_with_rule(
            query,
            DistributedPhysicalOptimizerRule::new()
                .with_maximum_partitions_per_task(partitions_per_task),
        )
        .await
    }

    async fn sql_to_explain_with_rule(
        query: &str,
        rule: DistributedPhysicalOptimizerRule,
    ) -> Result<String, DataFusionError> {
        let config = SessionConfig::new().with_target_partitions(4);

        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(rule))
            .with_config(config)
            .build();

        let ctx = SessionContext::new_with_state(state);
        register_parquet_tables(&ctx).await?;

        let df = ctx.sql(query).await?;

        let physical_plan = df.create_physical_plan().await?;
        let display = displayable(physical_plan.as_ref()).indent(true).to_string();

        Ok(display)
    }
}
```

src/test_utils/insta.rs

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@

```rust
use std::env;

#[macro_export]
macro_rules! assert_snapshot {
    ($($arg:tt)*) => {
        crate::test_utils::insta::settings().bind(|| {
            insta::assert_snapshot!($($arg)*);
        })
    };
}

pub fn settings() -> insta::Settings {
    env::set_var("INSTA_WORKSPACE_ROOT", env!("CARGO_MANIFEST_DIR"));
    let mut settings = insta::Settings::clone_current();
    let cwd = env::current_dir().unwrap();
    let cwd = cwd.to_str().unwrap();
    settings.add_filter(cwd.trim_start_matches("/"), "");
    settings.add_filter(
        r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
        "UUID",
    );

    settings
}
```
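
The two filters keep snapshots machine-independent: the absolute working directory is stripped out of paths (which is why the plans above print /testdata/weather.parquet rather than a full local path), and any UUID is replaced with the literal string UUID. A minimal usage sketch of the exported wrapper (hypothetical test, not part of this commit):

```rust
// Hypothetical usage of the wrapper macro above. Because it is
// #[macro_export]-ed, it is invoked through the crate root; the bound settings
// apply the path and UUID filters before insta compares the value against the
// inline snapshot.
#[test]
fn snapshot_example() {
    let value = format!("{} + {} = {}", 1, 2, 1 + 2);
    crate::assert_snapshot!(value, @"1 + 2 = 3");
}
```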

src/test_utils/mod.rs

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
-#[cfg(test)]
+pub mod insta;
 mod mock_exec;
+mod parquet;
 
-#[cfg(test)]
 pub use mock_exec::MockExec;
+pub use parquet::register_parquet_tables;
```
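
The fourth changed file, the new src/test_utils/parquet.rs behind `mod parquet;`, is not expanded on this page. As a hedged sketch only, assuming the helper registers the test file through DataFusion's `SessionContext::register_parquet` (the actual implementation in the commit may differ):

```rust
// Hypothetical sketch of the collapsed src/test_utils/parquet.rs: assumes it
// simply registers the test parquet file as the "weather" table used by the
// planner tests above.
use datafusion::error::DataFusionError;
use datafusion::prelude::{ParquetReadOptions, SessionContext};

pub async fn register_parquet_tables(ctx: &SessionContext) -> Result<(), DataFusionError> {
    ctx.register_parquet(
        "weather",
        "testdata/weather.parquet",
        ParquetReadOptions::default(),
    )
    .await
}
```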
