Commit 8067b44

Add stage planner tests
1 parent 17be2c5 commit 8067b44

File tree

src/physical_optimizer.rs
src/test_utils/insta.rs
src/test_utils/mod.rs
src/test_utils/parquet.rs

4 files changed: +251 -2 lines changed


src/physical_optimizer.rs

Lines changed: 204 additions & 0 deletions
@@ -258,3 +258,207 @@ impl TreeNodeRewriter for StagePlanner {
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::assert_snapshot;
    use crate::physical_optimizer::DistributedPhysicalOptimizerRule;
    use crate::test_utils::register_parquet_tables;
    use datafusion::error::DataFusionError;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_plan::displayable;
    use datafusion::prelude::{SessionConfig, SessionContext};
    use std::sync::Arc;

    /* schema for the "weather" table

    MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
    MaxTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
    Rainfall [type=DOUBLE] [repetitiontype=OPTIONAL]
    Evaporation [type=DOUBLE] [repetitiontype=OPTIONAL]
    Sunshine [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindGustDir [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindGustSpeed [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindDir9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindDir3pm [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindSpeed9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    WindSpeed3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Humidity9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Humidity3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Pressure9am [type=DOUBLE] [repetitiontype=OPTIONAL]
    Pressure3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
    Cloud9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Cloud3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
    Temp9am [type=DOUBLE] [repetitiontype=OPTIONAL]
    Temp3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
    RainToday [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    RISK_MM [type=DOUBLE] [repetitiontype=OPTIONAL]
    RainTomorrow [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
    */

    #[tokio::test]
    async fn test_select_all() {
        let query = r#"SELECT * FROM weather"#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 1 Task: partitions: 0,unassigned]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_aggregation() {
        let query =
            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 3 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 2

        └──────────────────────────────────────────────────
        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4
        │partitions [out:4 ] ArrowFlightReadExec: Stage 1

        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_aggregation_with_partitions_per_task() {
        let query =
            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
        let plan = sql_to_explain_partitions_per_task(query, 2).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 3 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 2

        └──────────────────────────────────────────────────
        ┌───── Stage 2 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=2
        │partitions [out:2 <-- in:4 ] PartitionIsolatorExec [providing upto 2 partitions]
        │partitions [out:4 ] ArrowFlightReadExec: Stage 1

        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2
        │partitions [out:2 <-- in:1 ] PartitionIsolatorExec [providing upto 2 partitions]
        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_left_join() {
        let query = r#"SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday" "#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 1 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:1 <-- in:1 ] HashJoinExec: mode=Partitioned, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_sort() {
        let query = r#"SELECT * FROM weather ORDER BY "MinTemp" DESC "#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 1 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] SortExec: expr=[MinTemp@0 DESC], preserve_partitioning=[false]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    #[tokio::test]
    async fn test_distinct() {
        let query = r#"SELECT DISTINCT "RainToday", "WindGustDir" FROM weather"#;
        let plan = sql_to_explain(query).await.unwrap();
        assert_snapshot!(plan, @r"
        ┌───── Stage 3 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
        │partitions [out:4 ] ArrowFlightReadExec: Stage 2

        └──────────────────────────────────────────────────
        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0, WindGustDir@1], 4), input_partitions=4
        │partitions [out:4 ] ArrowFlightReadExec: Stage 1

        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday, WindGustDir], file_type=parquet

        └──────────────────────────────────────────────────
        ");
    }

    async fn sql_to_explain(query: &str) -> Result<String, DataFusionError> {
        sql_to_explain_with_rule(query, DistributedPhysicalOptimizerRule::new()).await
    }

    async fn sql_to_explain_partitions_per_task(
        query: &str,
        partitions_per_task: usize,
    ) -> Result<String, DataFusionError> {
        sql_to_explain_with_rule(
            query,
            DistributedPhysicalOptimizerRule::new()
                .with_maximum_partitions_per_task(partitions_per_task),
        )
        .await
    }

    async fn sql_to_explain_with_rule(
        query: &str,
        rule: DistributedPhysicalOptimizerRule,
    ) -> Result<String, DataFusionError> {
        let config = SessionConfig::new().with_target_partitions(4);

        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(rule))
            .with_config(config)
            .build();

        let ctx = SessionContext::new_with_state(state);
        register_parquet_tables(&ctx).await?;

        let df = ctx.sql(query).await?;

        let physical_plan = df.create_physical_plan().await?;
        let display = displayable(physical_plan.as_ref()).indent(true).to_string();

        Ok(display)
    }
}
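
A note on reading the snapshots above: each ┌───── Stage N box is one stage produced by the stage planner, shown with its task/partition assignments, and the [out:N <-- in:M] prefix on an operator line gives that operator's output and input partition counts. The @r"..." literals are insta inline snapshots; assuming cargo-insta is installed, running cargo insta review refreshes them when the planner output legitimately changes.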

src/test_utils/insta.rs

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
use std::env;

#[macro_export]
macro_rules! assert_snapshot {
    ($($arg:tt)*) => {
        crate::test_utils::insta::settings().bind(|| {
            insta::assert_snapshot!($($arg)*);
        })
    };
}

pub fn settings() -> insta::Settings {
    env::set_var("INSTA_WORKSPACE_ROOT", env!("CARGO_MANIFEST_DIR"));
    let mut settings = insta::Settings::clone_current();
    let cwd = env::current_dir().unwrap();
    let cwd = cwd.to_str().unwrap();
    settings.add_filter(cwd.trim_start_matches("/"), "");
    settings.add_filter(
        r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
        "UUID",
    );

    settings
}
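
For context, this wrapper binds the settings around every assertion: the first filter erases the current working directory from snapshot text (which is why the file_groups paths in the stage planner snapshots read /testdata/weather.parquet no matter where the tests run), and the second normalizes UUIDs. A minimal usage sketch, assuming a #[cfg(test)] module inside this crate; the test and the asserted value are hypothetical, not part of this commit:

#[cfg(test)]
mod example {
    // Hypothetical test: the cwd filter rewrites the path before comparison.
    // DataFusion renders file paths as object-store paths (no leading slash),
    // which stripping the leading slash from cwd mimics here.
    #[test]
    fn path_is_normalized() {
        let cwd = std::env::current_dir().unwrap();
        let value = format!(
            "{}/testdata/weather.parquet",
            cwd.display().to_string().trim_start_matches('/')
        );
        crate::assert_snapshot!(value, @"/testdata/weather.parquet");
    }
}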

src/test_utils/mod.rs

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
-#[cfg(test)]
+pub mod insta;
 mod mock_exec;
+mod parquet;
 
-#[cfg(test)]
 pub use mock_exec::MockExec;
+pub use parquet::register_parquet_tables;

src/test_utils/parquet.rs

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
use datafusion::error::DataFusionError;
use datafusion::prelude::{ParquetReadOptions, SessionContext};

pub async fn register_parquet_tables(ctx: &SessionContext) -> Result<(), DataFusionError> {
    ctx.register_parquet(
        "flights_1m",
        "testdata/flights-1m.parquet",
        ParquetReadOptions::default(),
    )
    .await?;

    ctx.register_parquet(
        "weather",
        "testdata/weather.parquet",
        ParquetReadOptions::default(),
    )
    .await?;

    Ok(())
}
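
As the stage planner tests show, a caller only needs a SessionContext. A minimal consumption sketch (same module context assumed; the function name example_query is hypothetical):

use datafusion::error::DataFusionError;
use datafusion::prelude::SessionContext;

// Hypothetical caller: register both test tables, then query one of them.
async fn example_query() -> Result<(), DataFusionError> {
    let ctx = SessionContext::new();
    register_parquet_tables(&ctx).await?;
    // "weather" and "flights_1m" are now resolvable from SQL.
    let df = ctx.sql(r#"SELECT "RainToday" FROM weather LIMIT 5"#).await?;
    df.show().await?;
    Ok(())
}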
