 
 use std::sync::Arc;
 
+use crate::memory_limit::DummyStreamPartition;
 use crate::physical_optimizer::test_utils::{
     aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition,
     check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema,
@@ -32,11 +33,11 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::{DataType, SchemaRef};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{TreeNode, TransformedResult};
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::{Result, ScalarValue, TableReference};
 use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_datasource::source::DataSourceExec;
 use datafusion_expr_common::operator::Operator;
-use datafusion_expr::{JoinType, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition};
+use datafusion_expr::{JoinType, SortExpr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition};
 use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_functions_aggregate::average::avg_udaf;
 use datafusion_functions_aggregate::count::count_udaf;
@@ -61,7 +62,14 @@ use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown
 use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution;
 use datafusion_physical_optimizer::output_requirements::OutputRequirementExec;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
-
+use datafusion::prelude::*;
+use arrow::array::{Int32Array, RecordBatch};
+use arrow::datatypes::Field;
+use arrow_schema::Schema;
+use datafusion_execution::TaskContext;
+use datafusion_catalog::streaming::StreamingTable;
+
+use futures::StreamExt;
 use rstest::rstest;
 
 /// Create a sorted Csv exec
@@ -879,6 +887,7 @@ async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()>
     assert_optimized!(expected_input, expected_optimized, physical_plan, true);
     Ok(())
 }
+
 #[tokio::test]
 async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> {
     let schema = create_test_schema()?;
@@ -3842,3 +3851,124 @@ fn test_parallelize_sort_preserves_fetch() -> Result<()> {
     );
     Ok(())
 }
+
+#[tokio::test]
+async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
+    // Create schema for the table
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Int32, false),
+        Field::new("b", DataType::Int32, false),
+        Field::new("c", DataType::Int32, false),
+    ]));
+
+    // Create homogeneous batches - each batch has the same values for columns a and b
+    let batch1 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 1])),
+            Arc::new(Int32Array::from(vec![1, 1, 1])),
+            Arc::new(Int32Array::from(vec![3, 2, 1])),
+        ],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![2, 2, 2])),
+            Arc::new(Int32Array::from(vec![2, 2, 2])),
+            Arc::new(Int32Array::from(vec![4, 6, 5])),
+        ],
+    )?;
+    let batch3 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![3, 3, 3])),
+            Arc::new(Int32Array::from(vec![3, 3, 3])),
+            Arc::new(Int32Array::from(vec![9, 7, 8])),
+        ],
+    )?;
+
+    // Create session with batch size of 3 to match our homogeneous batch pattern
+    let session_config = SessionConfig::new()
+        .with_batch_size(3)
+        .with_target_partitions(1);
+    let ctx = SessionContext::new_with_config(session_config);
+
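+    // The stream below is declared as pre-sorted on (a, b), while the query orders by
+    // (a, c). The shared prefix `a` is what lets the optimizer replace a full sort with
+    // a PartialSortExec that only sorts within each group of equal `a` values.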
3896+ let sort_order = vec ! [
3897+ SortExpr :: new(
3898+ Expr :: Column ( datafusion_common:: Column :: new(
3899+ Option :: <TableReference >:: None ,
3900+ "a" ,
3901+ ) ) ,
3902+ true ,
3903+ false ,
3904+ ) ,
3905+ SortExpr :: new(
3906+ Expr :: Column ( datafusion_common:: Column :: new(
3907+ Option :: <TableReference >:: None ,
3908+ "b" ,
3909+ ) ) ,
3910+ true ,
3911+ false ,
3912+ ) ,
3913+ ] ;
3914+ let batches = Arc :: new ( DummyStreamPartition {
3915+ schema : schema. clone ( ) ,
3916+ batches : vec ! [ batch1, batch2, batch3] ,
3917+ } ) as _ ;
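+    // Expose the batches as a StreamingTable that advertises the (a, b) ordering and is
+    // marked as infinite, so buffering the whole input for a full sort is not an option.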
+    let provider = StreamingTable::try_new(schema.clone(), vec![batches])?
+        .with_sort_order(sort_order)
+        .with_infinite_table(true);
+    ctx.register_table("test_table", Arc::new(provider))?;
+
+    let sql = "SELECT * FROM test_table ORDER BY a ASC, c ASC";
+    let df = ctx.sql(sql).await?;
+
+    let physical_plan = df.create_physical_plan().await?;
+
+    // Verify that PartialSortExec is used
+    let plan_str = displayable(physical_plan.as_ref()).indent(true).to_string();
+    assert!(
+        plan_str.contains("PartialSortExec"),
+        "Expected PartialSortExec in plan:\n{plan_str}",
+    );
+
+    let task_ctx = Arc::new(TaskContext::default());
+    let mut stream = physical_plan.execute(0, task_ctx.clone())?;
+
+    let mut collected_batches = Vec::new();
+    while let Some(batch) = stream.next().await {
+        let batch = batch?;
+        if batch.num_rows() > 0 {
+            collected_batches.push(batch);
+        }
+    }
+
+    // Assert we got 3 separate batches (not concatenated into fewer)
+    assert_eq!(
+        collected_batches.len(),
+        3,
+        "Expected 3 separate batches, got {}",
+        collected_batches.len()
+    );
+
+    // Verify each batch has been sorted within itself
+    let expected_values = [vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
+
+    for (i, batch) in collected_batches.iter().enumerate() {
+        let c_array = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let actual = c_array.values().iter().copied().collect::<Vec<i32>>();
+        assert_eq!(actual, expected_values[i], "Batch {i} not sorted correctly");
+    }
+
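+    // The partial sort streams results incrementally, so no sort reservation should
+    // remain once the stream has been fully drained.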
+    assert_eq!(
+        task_ctx.runtime_env().memory_pool.reserved(),
+        0,
+        "Memory should be released after execution"
+    );
+
+    Ok(())
+}