@@ -258,3 +258,207 @@ impl TreeNodeRewriter for StagePlanner {
         }
     }
 }
+
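+// Snapshot tests for the distributed planner: each case runs a SQL query
+// through `DistributedPhysicalOptimizerRule` and asserts on the rendered
+// stage tree (stages, tasks, and per-operator partition counts).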
+#[cfg(test)]
+mod tests {
+    use crate::assert_snapshot;
+    use crate::physical_optimizer::DistributedPhysicalOptimizerRule;
+    use crate::test_utils::register_parquet_tables;
+    use datafusion::error::DataFusionError;
+    use datafusion::execution::SessionStateBuilder;
+    use datafusion::physical_plan::displayable;
+    use datafusion::prelude::{SessionConfig, SessionContext};
+    use std::sync::Arc;
+
+    /* schema for the "weather" table
+
+    MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
+    MaxTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Rainfall [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Evaporation [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Sunshine [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindGustDir [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindGustSpeed [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindDir9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindDir3pm [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindSpeed9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    WindSpeed3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Humidity9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Humidity3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Pressure9am [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Pressure3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Cloud9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Cloud3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
+    Temp9am [type=DOUBLE] [repetitiontype=OPTIONAL]
+    Temp3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
+    RainToday [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    RISK_MM [type=DOUBLE] [repetitiontype=OPTIONAL]
+    RainTomorrow [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
+    */
+
+    #[tokio::test]
+    async fn test_select_all() {
+        let query = r#"SELECT * FROM weather"#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
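+    // The aggregation splits into three stages: a partial aggregate over the
+    // scan, a hash repartition on the group key, and a final aggregate with
+    // sort/merge, linked by ArrowFlightReadExec boundaries.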
+    #[tokio::test]
+    async fn test_aggregation() {
+        let query =
+            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 3 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
+        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
+        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
+        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
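+    // Same query, capped at 2 partitions per task: stages 1 and 2 each split
+    // into two tasks, with a PartitionIsolatorExec restricting how many
+    // partitions a single task serves.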
+    #[tokio::test]
+    async fn test_aggregation_with_partitions_per_task() {
+        let query =
+            r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#;
+        let plan = sql_to_explain_partitions_per_task(query, 2).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 3 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
+        │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
+        │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
+        │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
+        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=2
+        │partitions [out:2 <-- in:4 ] PartitionIsolatorExec [providing upto 2 partitions]
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned]
+        │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2
+        │partitions [out:2 <-- in:1 ] PartitionIsolatorExec [providing upto 2 partitions]
+        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
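+    // With a single input partition there is no exchange to split on, so the
+    // whole join stays in one stage.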
+    #[tokio::test]
+    async fn test_left_join() {
+        let query = r#"SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday" "#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:1 <-- in:1 ] HashJoinExec: mode=Partitioned, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
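+    // Likewise, a sort over the single-partition scan plans as one stage.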
+    #[tokio::test]
+    async fn test_sort() {
+        let query = r#"SELECT * FROM weather ORDER BY "MinTemp" DESC "#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] SortExec: expr=[MinTemp@0 DESC], preserve_partitioning=[false]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
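+    // DISTINCT plans as the same three-stage partial/final aggregate shape
+    // as GROUP BY, just with an empty aggr list.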
+    #[tokio::test]
+    async fn test_distinct() {
+        let query = r#"SELECT DISTINCT "RainToday", "WindGustDir" FROM weather"#;
+        let plan = sql_to_explain(query).await.unwrap();
+        assert_snapshot!(plan, @r"
+        ┌───── Stage 3 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
+        │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 2
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0, WindGustDir@1], 4), input_partitions=4
+        │partitions [out:4 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0..3,unassigned]
+        │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday, WindGustDir], file_type=parquet
+        │
+        └──────────────────────────────────────────────────
+        ");
+    }
+
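+    // Test helpers: plan `query` with the given rule and render the staged
+    // physical plan as a string.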
+    async fn sql_to_explain(query: &str) -> Result<String, DataFusionError> {
+        sql_to_explain_with_rule(query, DistributedPhysicalOptimizerRule::new()).await
+    }
+
+    async fn sql_to_explain_partitions_per_task(
+        query: &str,
+        partitions_per_task: usize,
+    ) -> Result<String, DataFusionError> {
+        sql_to_explain_with_rule(
+            query,
+            DistributedPhysicalOptimizerRule::new()
+                .with_maximum_partitions_per_task(partitions_per_task),
+        )
+        .await
+    }
+
+    async fn sql_to_explain_with_rule(
+        query: &str,
+        rule: DistributedPhysicalOptimizerRule,
+    ) -> Result<String, DataFusionError> {
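+        // Pin target partitions to 4 so the fan-outs in the snapshots above
+        // (RoundRobinBatch(4), Hash(..., 4)) are deterministic.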
+        let config = SessionConfig::new().with_target_partitions(4);
+
+        let state = SessionStateBuilder::new()
+            .with_default_features()
+            .with_physical_optimizer_rule(Arc::new(rule))
+            .with_config(config)
+            .build();
+
+        let ctx = SessionContext::new_with_state(state);
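+        // `register_parquet_tables` (from `crate::test_utils`) is assumed to
+        // register `/testdata/weather.parquet` as the `weather` table,
+        // e.g. via `ctx.register_parquet("weather", "/testdata/weather.parquet",
+        // ParquetReadOptions::default())`.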
+        register_parquet_tables(&ctx).await?;
+
+        let df = ctx.sql(query).await?;
+
+        let physical_plan = df.create_physical_plan().await?;
+        let display = displayable(physical_plan.as_ref()).indent(true).to_string();
+
+        Ok(display)
+    }
+}