@@ -258,3 +258,280 @@ impl TreeNodeRewriter for StagePlanner {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::assert_snapshot;
+    use crate::physical_optimizer::DistributedPhysicalOptimizerRule;
+    use crate::test_utils::register_parquet_tables;
+    use datafusion::error::DataFusionError;
+    use datafusion::execution::SessionStateBuilder;
+    use datafusion::physical_plan::displayable;
+    use datafusion::prelude::{SessionConfig, SessionContext};
+    use std::sync::Arc;
+
+    /* schema for the "weather" table
+
+    MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
+    MaxTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Rainfall [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Evaporation [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Sunshine [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindGustDir [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindGustSpeed [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindDir9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindDir3pm [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindSpeed9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindSpeed3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Humidity9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Humidity3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Pressure9am [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Pressure3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Cloud9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Cloud3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Temp9am [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Temp3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
+    RainToday [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    RISK_MM [type=DOUBLE] [repetitiontype=OPTIONAL]
+    RainTomorrow [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    */
+
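+    // A bare `SELECT *` needs no repartitioning, so the optimizer emits a single
+    // stage that reads the parquet file directly.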
+    #[tokio::test]
+    async fn test_select_all() {
+        let query = r#"SELECT * FROM weather"#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
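+    // GROUP BY + ORDER BY splits into three stages: a partial aggregate next to the
+    // scan, a hash repartition on the grouping key, and a final aggregate plus sort.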
+    #[tokio::test]
+    async fn test_aggregation() {
+        let query =
+            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 3 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
+        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
+        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
+        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
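+    // Same query as above, but with each task capped at 2 partitions: Stages 1 and 2
+    // are split across two tasks, each reading through a PartitionIsolatorExec.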
+    #[tokio::test]
+    async fn test_aggregation_with_partitions_per_task() {
+        let query =
+            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
+        let plan = sql_to_explain_partitions_per_task(query, 2).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 3 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
+        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
+        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
+        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
+        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=2
+        │partitions [out:2 <-- in:4 ] PartitionIsolatorExec [providing upto 2 partitions]
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
+        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2
+        │partitions [out:2 <-- in:1 ] PartitionIsolatorExec [providing upto 2 partitions]
+        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
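+    // A self join over single-partition scans involves no RepartitionExec, so there
+    // is nothing for the rule to split on and the whole join stays in one stage.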
+    #[tokio::test]
+    async fn test_left_join() {
+        let query = r#"SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday" "#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:1 <-- in:1 ] HashJoinExec: mode=Partitioned, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
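+    // Joining two aggregated CTEs yields a five-stage plan: each CTE becomes a scan
+    // stage plus a hash-repartition stage (Stages 1-4), and Stage 5 runs the final
+    // aggregates and a CollectLeft hash join.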
+    #[tokio::test]
+    async fn test_left_join_distributed() {
+        let query = r#"
+        WITH a AS (
+            SELECT
+                AVG("MinTemp") as "MinTemp",
+                "RainTomorrow"
+            FROM weather
+            WHERE "RainToday" = 'yes'
+            GROUP BY "RainTomorrow"
+        ), b AS (
+            SELECT
+                AVG("MaxTemp") as "MaxTemp",
+                "RainTomorrow"
+            FROM weather
+            WHERE "RainToday" = 'no'
+            GROUP BY "RainTomorrow"
+        )
+        SELECT
+            a."MinTemp",
+            b."MaxTemp"
+        FROM a
+        LEFT JOIN b
+        ON a."RainTomorrow" = b."RainTomorrow"
+        "#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 5 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 <-- in:1 ] HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainTomorrow@1, RainTomorrow@1)], projection=[MinTemp@0, MaxTemp@2]
+        │partitions [out:1 <-- in:4 ] CoalescePartitionsExec
+        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[avg(weather.MinTemp)@1 as MinTemp, RainTomorrow@0 as RainTomorrow]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MinTemp)]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
+        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[avg(weather.MaxTemp)@1 as MaxTemp, RainTomorrow@0 as RainTomorrow]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MaxTemp)]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 4
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 4 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MaxTemp)]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = no, projection=[MaxTemp@0, RainTomorrow@2]
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 3
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 3 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = no, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= no AND no <= RainToday_max@1, required_guarantees=[RainToday in (no)]
+        │
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MinTemp)]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = yes, projection=[MinTemp@0, RainTomorrow@2]
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = yes, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= yes AND yes <= RainToday_max@1, required_guarantees=[RainToday in (yes)]
+        │
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
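+    // A global sort over a single-partition scan needs no exchange, so it also
+    // stays in a single stage.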
+    #[tokio::test]
+    async fn test_sort() {
+        let query = r#"SELECT * FROM weather ORDER BY "MinTemp" DESC "#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] SortExec: expr=[MinTemp@0 DESC], preserve_partitioning=[false]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
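+    // DISTINCT is planned as a group-by with no aggregate expressions, so it gets
+    // the same three-stage partial/repartition/final shape as test_aggregation.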
+    #[tokio::test]
+    async fn test_distinct() {
+        let query = r#"SELECT DISTINCT "RainToday", "WindGustDir" FROM weather"#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 3 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0, WindGustDir@1], 4), input_partitions=4
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday, WindGustDir], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
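+    /// Plans `query` with the default `DistributedPhysicalOptimizerRule` and
+    /// returns the rendered physical plan.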
+    async fn sql_to_explain(query: &str) -> Result<String, DataFusionError> {
+        sql_to_explain_with_rule(query, DistributedPhysicalOptimizerRule::new()).await
+    }
+
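+    /// Same as `sql_to_explain`, but caps how many partitions a single task may
+    /// execute, forcing stages to be split across multiple tasks.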
+    async fn sql_to_explain_partitions_per_task(
+        query: &str,
+        partitions_per_task: usize,
+    ) -> Result<String, DataFusionError> {
+        sql_to_explain_with_rule(
+            query,
+            DistributedPhysicalOptimizerRule::new()
+                .with_maximum_partitions_per_task(partitions_per_task),
+        )
+        .await
+    }
+
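+    /// Builds a session with 4 target partitions and the given rule, registers the
+    /// test parquet tables, plans `query`, and renders the resulting physical plan.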
+    async fn sql_to_explain_with_rule(
+        query: &str,
+        rule: DistributedPhysicalOptimizerRule,
+    ) -> Result<String, DataFusionError> {
+        let config = SessionConfig::new().with_target_partitions(4);
+
+        let state = SessionStateBuilder::new()
+            .with_default_features()
+            .with_physical_optimizer_rule(Arc::new(rule))
+            .with_config(config)
+            .build();
+
+        let ctx = SessionContext::new_with_state(state);
+        register_parquet_tables(&ctx).await?;
+
+        let df = ctx.sql(query).await?;
+
+        let physical_plan = df.create_physical_plan().await?;
+        let display = displayable(physical_plan.as_ref()).indent(true).to_string();
+
+        Ok(display)
+    }
+}