@@ -24,6 +24,7 @@ use arrow::{
 };
 use arrow_schema::SortOptions;
 use datafusion::{
+    assert_batches_eq,
     logical_expr::Operator,
     physical_plan::{
         expressions::{BinaryExpr, Column, Literal},
@@ -736,6 +737,101 @@ async fn test_topk_dynamic_filter_pushdown() {
     );
 }
 
+#[tokio::test]
+async fn test_topk_dynamic_filter_pushdown_multi_column_sort() {
+    let batches = vec![
+        // We are going to do ORDER BY b ASC NULLS LAST, a DESC.
+        // The values are arranged so that the first batch fills the TopK
+        // and the second batch can be skipped entirely.
+        record_batch!(
+            ("a", Utf8, ["ac", "ad"]),
+            ("b", Utf8, ["bb", "ba"]),
+            ("c", Float64, [2.0, 1.0])
+        )
+        .unwrap(),
+        record_batch!(
+            ("a", Utf8, ["aa", "ab"]),
+            ("b", Utf8, ["bc", "bd"]),
+            ("c", Float64, [1.0, 2.0])
+        )
+        .unwrap(),
+    ];
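+    // Under this ordering the two smallest rows are ("ad", "ba", 1.0) and
+    // ("ac", "bb", 2.0), both in the first batch, so after that batch the
+    // TopK heap is full and every row of the second batch loses to it.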
+    let scan = TestScanBuilder::new(schema())
+        .with_support(true)
+        .with_batches(batches)
+        .build();
+    let plan = Arc::new(
+        SortExec::new(
+            LexOrdering::new(vec![
+                PhysicalSortExpr::new(
+                    col("b", &schema()).unwrap(),
+                    SortOptions::default().asc().nulls_last(),
+                ),
+                PhysicalSortExpr::new(
+                    col("a", &schema()).unwrap(),
+                    SortOptions::default().desc().nulls_first(),
+                ),
+            ])
+            .unwrap(),
+            Arc::clone(&scan),
+        )
+        .with_fetch(Some(2)),
+    ) as Arc<dyn ExecutionPlan>;
+
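+    // A SortExec with a fetch is a TopK, so this plan is the physical
+    // equivalent of `SELECT * FROM t ORDER BY b ASC NULLS LAST, a DESC LIMIT 2`
+    // (table name `t` is illustrative).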
+    // expect the predicate to be pushed down into the DataSource
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false]
+        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false]
+          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ true ]
+    "
+    );
+
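+    // No data has flowed yet, so the pushed-down dynamic filter is still the
+    // placeholder `true`; it only tightens once the TopK heap fills.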
+    // Actually apply the optimization to the plan and put some data through it
+    // to check that the filter is updated to reflect the TopK state.
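+    // Enable Parquet filter pushdown so the scan evaluates the dynamic
+    // predicate instead of ignoring it.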
+    let mut config = ConfigOptions::default();
+    config.execution.parquet.pushdown_filters = true;
+    let plan = FilterPushdown::new_post_optimization()
+        .optimize(plan, &config)
+        .unwrap();
+    let config = SessionConfig::new().with_batch_size(2);
+    let session_ctx = SessionContext::new_with_config(config);
+    session_ctx.register_object_store(
+        ObjectStoreUrl::parse("test://").unwrap().as_ref(),
+        Arc::new(InMemory::new()),
+    );
+    let state = session_ctx.state();
+    let task_ctx = state.task_ctx();
+    let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap();
+    // Iterate one batch
+    let res = stream.next().await.unwrap().unwrap();
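+    // The single output batch should hold the top 2 rows in sort order: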
+    #[rustfmt::skip]
+    let expected = [
+        "+----+----+-----+",
+        "| a  | b  | c   |",
+        "+----+----+-----+",
+        "| ad | ba | 1.0 |",
+        "| ac | bb | 2.0 |",
+        "+----+----+-----+",
+    ];
+    assert_batches_eq!(expected, &[res]);
+    // Now check what our filter looks like
+    insta::assert_snapshot!(
+        format!("{}", format_plan_for_test(&plan)),
+        @r"
+    - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false], filter=[b@1 < bb OR b@1 = bb AND (a@0 IS NULL OR a@0 > ac)]
+    - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ b@1 < bb OR b@1 = bb AND (a@0 IS NULL OR a@0 > ac) ]
+    "
+    );
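+    // The filter mirrors the worst row in the TopK heap, (a = "ac", b = "bb"):
+    // a new row can only displace it if it sorts strictly earlier, i.e.
+    // b < "bb", or b == "bb" with a ranked before "ac" under DESC NULLS FIRST
+    // (a > "ac", or a IS NULL since nulls sort first).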
+    // There should be no more batches
+    assert!(stream.next().await.is_none());
+}
+
 #[tokio::test]
 async fn test_hashjoin_dynamic_filter_pushdown() {
     use datafusion_common::JoinType;