@@ -3,15 +3,13 @@ use std::{collections::HashMap, sync::Arc};
33use datafusion:: {
44 error:: { DataFusionError , Result } ,
55 logical_plan:: {
6- plan:: {
7- Aggregate , CrossJoin , Distinct , Join , Limit , Projection , Sort , Subquery , Union , Window ,
8- } ,
6+ plan:: { Aggregate , Distinct , Limit , Projection , Sort , Subquery , Union , Window } ,
97 Column , DFSchema , Expr , Filter , LogicalPlan ,
108 } ,
119 optimizer:: optimizer:: { OptimizerConfig , OptimizerRule } ,
1210} ;
1311
14- use super :: utils:: { get_schema_columns , is_column_expr, plan_has_projections, rewrite} ;
12+ use super :: utils:: { is_column_expr, plan_has_projections, rewrite} ;
1513
1614/// Sort Push Down optimizer rule pushes ORDER BY clauses consisting of specific,
1715/// mostly simple, expressions down the plan, all the way to the Projection
@@ -167,97 +165,6 @@ fn sort_push_down(
167165 optimizer_config,
168166 )
169167 }
170- LogicalPlan :: Join ( Join {
171- left,
172- right,
173- on,
174- join_type,
175- join_constraint,
176- schema,
177- null_equals_null,
178- } ) => {
179- // DataFusion preserves the sorting of the joined plans, prioritizing left side.
180- // Taking this into account, we can push Sort down the left plan if Sort references
181- // columns just from the left side.
182- // TODO: check if this is still the case with multiple target partitions
183- if let Some ( some_sort_expr) = & sort_expr {
184- let left_columns = get_schema_columns ( left. schema ( ) ) ;
185- if some_sort_expr. iter ( ) . all ( |expr| {
186- if let Expr :: Sort { expr, .. } = expr {
187- if let Expr :: Column ( column) = expr. as_ref ( ) {
188- return left_columns. contains ( column) ;
189- }
190- }
191- false
192- } ) {
193- return Ok ( LogicalPlan :: Join ( Join {
194- left : Arc :: new ( sort_push_down (
195- optimizer,
196- left,
197- sort_expr,
198- optimizer_config,
199- ) ?) ,
200- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
201- on : on. clone ( ) ,
202- join_type : * join_type,
203- join_constraint : * join_constraint,
204- schema : schema. clone ( ) ,
205- null_equals_null : * null_equals_null,
206- } ) ) ;
207- }
208- }
209-
210- issue_sort (
211- sort_expr,
212- LogicalPlan :: Join ( Join {
213- left : Arc :: new ( sort_push_down ( optimizer, left, None , optimizer_config) ?) ,
214- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
215- on : on. clone ( ) ,
216- join_type : * join_type,
217- join_constraint : * join_constraint,
218- schema : schema. clone ( ) ,
219- null_equals_null : * null_equals_null,
220- } ) ,
221- )
222- }
223- LogicalPlan :: CrossJoin ( CrossJoin {
224- left,
225- right,
226- schema,
227- } ) => {
228- // See `LogicalPlan::Join` notes above.
229- if let Some ( some_sort_expr) = & sort_expr {
230- let left_columns = get_schema_columns ( left. schema ( ) ) ;
231- if some_sort_expr. iter ( ) . all ( |expr| {
232- if let Expr :: Sort { expr, .. } = expr {
233- if let Expr :: Column ( column) = expr. as_ref ( ) {
234- return left_columns. contains ( column) ;
235- }
236- }
237- false
238- } ) {
239- return Ok ( LogicalPlan :: CrossJoin ( CrossJoin {
240- left : Arc :: new ( sort_push_down (
241- optimizer,
242- left,
243- sort_expr,
244- optimizer_config,
245- ) ?) ,
246- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
247- schema : schema. clone ( ) ,
248- } ) ) ;
249- }
250- }
251-
252- issue_sort (
253- sort_expr,
254- LogicalPlan :: CrossJoin ( CrossJoin {
255- left : Arc :: new ( sort_push_down ( optimizer, left, None , optimizer_config) ?) ,
256- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
257- schema : schema. clone ( ) ,
258- } ) ,
259- )
260- }
261168 LogicalPlan :: Union ( Union {
262169 inputs,
263170 schema,
@@ -384,15 +291,10 @@ mod tests {
384291 } ;
385292 use datafusion:: logical_plan:: { col, JoinType , LogicalPlanBuilder } ;
386293
387- fn optimize ( plan : & LogicalPlan ) -> Result < LogicalPlan > {
294+ fn optimize ( plan : & LogicalPlan ) -> LogicalPlan {
388295 let rule = SortPushDown :: new ( ) ;
389296 rule. optimize ( plan, & OptimizerConfig :: new ( ) )
390- }
391-
392- fn assert_optimized_plan_eq ( plan : LogicalPlan , expected : & str ) {
393- let optimized_plan = optimize ( & plan) . expect ( "failed to optimize plan" ) ;
394- let formatted_plan = format ! ( "{:?}" , optimized_plan) ;
395- assert_eq ! ( formatted_plan, expected) ;
297+ . expect ( "failed to optimize plan" )
396298 }
397299
398300 fn sort ( expr : Expr , asc : bool , nulls_first : bool ) -> Expr {
@@ -417,14 +319,7 @@ mod tests {
417319 ] ) ?
418320 . build ( ) ?;
419321
420- let expected = "\
421- Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
422- \n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
423- \n Projection: #t1.c1, #t1.c2, #t1.c3\
424- \n TableScan: t1 projection=None\
425- ";
426-
427- assert_optimized_plan_eq ( plan, expected) ;
322+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
428323 Ok ( ( ) )
429324 }
430325
@@ -450,16 +345,7 @@ mod tests {
450345 ] ) ?
451346 . build ( ) ?;
452347
453- let expected = "\
454- Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
455- \n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
456- \n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
457- \n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
458- \n Projection: #t1.c1, #t1.c2, #t1.c3\
459- \n TableScan: t1 projection=None\
460- ";
461-
462- assert_optimized_plan_eq ( plan, expected) ;
348+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
463349 Ok ( ( ) )
464350 }
465351
@@ -487,21 +373,12 @@ mod tests {
487373 ] ) ?
488374 . build ( ) ?;
489375
490- let expected = "\
491- Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
492- \n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
493- \n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
494- \n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
495- \n Projection: #t1.c1, #t1.c2, #t1.c3\
496- \n TableScan: t1 projection=None\
497- ";
498-
499- assert_optimized_plan_eq ( plan, expected) ;
376+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
500377 Ok ( ( ) )
501378 }
502379
503380 #[ test]
504- fn test_sort_down_join ( ) -> Result < ( ) > {
381+ fn test_sort_down_join_sort_left ( ) -> Result < ( ) > {
505382 let plan = LogicalPlanBuilder :: from (
506383 LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
507384 . project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -521,18 +398,12 @@ mod tests {
521398 . sort ( vec ! [ sort( col( "j1.c1" ) , true , false ) ] ) ?
522399 . build ( ) ?;
523400
524- let expected = "\
525- Projection: #j1.c1, #j2.c2\
526- \n Inner Join: #j1.key = #j2.key\
527- \n Sort: #j1.c1 ASC NULLS LAST\
528- \n Projection: #j1.key, #j1.c1\
529- \n TableScan: j1 projection=None\
530- \n Projection: #j2.key, #j2.c2\
531- \n TableScan: j2 projection=None\
532- ";
533-
534- assert_optimized_plan_eq ( plan, expected) ;
401+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
402+ Ok ( ( ) )
403+ }
535404
405+ #[ test]
406+ fn test_sort_down_join_sort_right ( ) -> Result < ( ) > {
536407 let plan = LogicalPlanBuilder :: from (
537408 LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
538409 . project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -552,23 +423,12 @@ mod tests {
552423 . sort ( vec ! [ sort( col( "j2.c2" ) , true , false ) ] ) ?
553424 . build ( ) ?;
554425
555- let expected = "\
556- Projection: #j1.c1, #j2.c2\
557- \n Sort: #j2.c2 ASC NULLS LAST\
558- \n Inner Join: #j1.key = #j2.key\
559- \n Projection: #j1.key, #j1.c1\
560- \n TableScan: j1 projection=None\
561- \n Projection: #j2.key, #j2.c2\
562- \n TableScan: j2 projection=None\
563- ";
564-
565- assert_optimized_plan_eq ( plan, expected) ;
566-
426+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
567427 Ok ( ( ) )
568428 }
569429
570430 #[ test]
571- fn test_sort_down_cross_join ( ) -> Result < ( ) > {
431+ fn test_sort_down_cross_join_sort_left ( ) -> Result < ( ) > {
572432 let plan = LogicalPlanBuilder :: from (
573433 LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
574434 . project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -583,18 +443,12 @@ mod tests {
583443 . sort ( vec ! [ sort( col( "j1.c1" ) , true , false ) ] ) ?
584444 . build ( ) ?;
585445
586- let expected = "\
587- Projection: #j1.c1, #j2.c2\
588- \n CrossJoin:\
589- \n Sort: #j1.c1 ASC NULLS LAST\
590- \n Projection: #j1.key, #j1.c1\
591- \n TableScan: j1 projection=None\
592- \n Projection: #j2.key, #j2.c2\
593- \n TableScan: j2 projection=None\
594- ";
595-
596- assert_optimized_plan_eq ( plan, expected) ;
446+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
447+ Ok ( ( ) )
448+ }
597449
450+ #[ test]
451+ fn test_sort_down_cross_join_sort_right ( ) -> Result < ( ) > {
598452 let plan = LogicalPlanBuilder :: from (
599453 LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
600454 . project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -609,17 +463,7 @@ mod tests {
609463 . sort ( vec ! [ sort( col( "j2.c2" ) , true , false ) ] ) ?
610464 . build ( ) ?;
611465
612- let expected = "\
613- Projection: #j1.c1, #j2.c2\
614- \n Sort: #j2.c2 ASC NULLS LAST\
615- \n CrossJoin:\
616- \n Projection: #j1.key, #j1.c1\
617- \n TableScan: j1 projection=None\
618- \n Projection: #j2.key, #j2.c2\
619- \n TableScan: j2 projection=None\
620- ";
621-
622- assert_optimized_plan_eq ( plan, expected) ;
466+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
623467
624468 Ok ( ( ) )
625469 }
0 commit comments