@@ -29,7 +29,7 @@ use datafusion::physical_plan::execution_plan::Boundedness;
2929use datafusion:: physical_plan:: ExecutionPlan ;
3030use datafusion:: prelude:: SessionContext ;
3131use datafusion_common:: { DataFusionError , JoinType , ScalarValue } ;
32- use datafusion_execution:: TaskContext ;
32+ use datafusion_execution:: { SendableRecordBatchStream , TaskContext } ;
3333use datafusion_expr_common:: operator:: Operator ;
3434use datafusion_expr_common:: operator:: Operator :: { Divide , Eq , Gt , Modulo } ;
3535use datafusion_functions_aggregate:: min_max;
@@ -42,12 +42,14 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
4242use datafusion_physical_optimizer:: ensure_coop:: EnsureCooperative ;
4343use datafusion_physical_optimizer:: PhysicalOptimizerRule ;
4444use datafusion_physical_plan:: coalesce_batches:: CoalesceBatchesExec ;
45+ use datafusion_physical_plan:: coop:: make_cooperative;
4546use datafusion_physical_plan:: filter:: FilterExec ;
4647use datafusion_physical_plan:: joins:: { HashJoinExec , PartitionMode , SortMergeJoinExec } ;
4748use datafusion_physical_plan:: memory:: { LazyBatchGenerator , LazyMemoryExec } ;
4849use datafusion_physical_plan:: projection:: ProjectionExec ;
4950use datafusion_physical_plan:: repartition:: RepartitionExec ;
5051use datafusion_physical_plan:: sorts:: sort:: SortExec ;
52+ use datafusion_physical_plan:: stream:: RecordBatchStreamAdapter ;
5153use datafusion_physical_plan:: union:: InterleaveExec ;
5254use futures:: StreamExt ;
5355use parking_lot:: RwLock ;
@@ -250,6 +252,58 @@ async fn agg_grouped_topk_yields(
250252 query_yields ( aggr, session_ctx. task_ctx ( ) ) . await
251253}
252254
255+ #[ rstest]
256+ #[ tokio:: test]
257+ // A test that mocks the behavior of `SpillManager::read_spill_as_stream` without file access
258+ // to verify that a cooperative stream would properly yields in a spill file read scenario
259+ async fn spill_reader_stream_yield ( ) -> Result < ( ) , Box < dyn Error > > {
260+ use datafusion_physical_plan:: common:: spawn_buffered;
261+
262+ // A mock stream that always returns `Poll::Ready(Some(...))` immediately
263+ let always_ready =
264+ make_lazy_exec ( "value" , false ) . execute ( 0 , SessionContext :: new ( ) . task_ctx ( ) ) ?;
265+
266+ // this function makes a consumer stream that resembles how read_stream from spill file is constructed
267+ let stream = make_cooperative ( always_ready) ;
268+
269+ // Set large buffer so that buffer always has free space for the producer/sender
270+ let buffer_capacity = 100_000 ;
271+ let mut mock_stream = spawn_buffered ( stream, buffer_capacity) ;
272+ let schema = mock_stream. schema ( ) ;
273+
274+ let consumer_stream = futures:: stream:: poll_fn ( move |cx| {
275+ let mut collected = vec ! [ ] ;
276+ // To make sure that inner stream is polled multiple times, loop until the buffer is full
277+ // Ideally, the stream will yield before the loop ends
278+ for _ in 0 ..buffer_capacity {
279+ match mock_stream. as_mut ( ) . poll_next ( cx) {
280+ Poll :: Ready ( Some ( Ok ( batch) ) ) => {
281+ collected. push ( batch) ;
282+ }
283+ Poll :: Ready ( Some ( Err ( e) ) ) => {
284+ return Poll :: Ready ( Some ( Err ( e) ) ) ;
285+ }
286+ Poll :: Ready ( None ) => {
287+ break ;
288+ }
289+ Poll :: Pending => {
290+ // polling inner stream may return Pending only when it reaches budget, since
291+ // we intentionally made ProducerStream always return Ready
292+ return Poll :: Pending ;
293+ }
294+ }
295+ }
296+
297+ // This should be unreachable since the stream is canceled
298+ unreachable ! ( "Expected the stream to be canceled, but it continued polling" ) ;
299+ } ) ;
300+
301+ let consumer_record_batch_stream =
302+ Box :: pin ( RecordBatchStreamAdapter :: new ( schema, consumer_stream) ) ;
303+
304+ stream_yields ( consumer_record_batch_stream) . await
305+ }
306+
253307#[ rstest]
254308#[ tokio:: test]
255309async fn sort_yields (
@@ -698,17 +752,9 @@ enum Yielded {
698752 Timeout ,
699753}
700754
701- async fn query_yields (
702- plan : Arc < dyn ExecutionPlan > ,
703- task_ctx : Arc < TaskContext > ,
755+ async fn stream_yields (
756+ mut stream : SendableRecordBatchStream ,
704757) -> Result < ( ) , Box < dyn Error > > {
705- // Run plan through EnsureCooperative
706- let optimized =
707- EnsureCooperative :: new ( ) . optimize ( plan, task_ctx. session_config ( ) . options ( ) ) ?;
708-
709- // Get the stream
710- let mut stream = physical_plan:: execute_stream ( optimized, task_ctx) ?;
711-
712758 // Create an independent executor pool
713759 let child_runtime = Runtime :: new ( ) ?;
714760
@@ -753,3 +799,18 @@ async fn query_yields(
753799 ) ;
754800 Ok ( ( ) )
755801}
802+
803+ async fn query_yields (
804+ plan : Arc < dyn ExecutionPlan > ,
805+ task_ctx : Arc < TaskContext > ,
806+ ) -> Result < ( ) , Box < dyn Error > > {
807+ // Run plan through EnsureCooperative
808+ let optimized =
809+ EnsureCooperative :: new ( ) . optimize ( plan, task_ctx. session_config ( ) . options ( ) ) ?;
810+
811+ // Get the stream
812+ let stream = physical_plan:: execute_stream ( optimized, task_ctx) ?;
813+
814+ // Spawn a task that tries to poll the stream and check whether given stream yields
815+ stream_yields ( stream) . await
816+ }
0 commit comments