1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use crate :: coalesce:: LimitedBatchCoalescer ;
1819use crate :: metrics:: { ExecutionPlanMetricsSet , MetricsSet } ;
1920use crate :: stream:: RecordBatchStreamAdapter ;
2021use crate :: {
@@ -24,16 +25,19 @@ use arrow::array::RecordBatch;
2425use arrow_schema:: { Fields , Schema , SchemaRef } ;
2526use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRecursion } ;
2627use datafusion_common:: { Result , assert_eq_or_internal_err} ;
27- use datafusion_execution:: { SendableRecordBatchStream , TaskContext } ;
28+ use datafusion_execution:: { RecordBatchStream , SendableRecordBatchStream , TaskContext } ;
2829use datafusion_physical_expr:: ScalarFunctionExpr ;
2930use datafusion_physical_expr:: async_scalar_function:: AsyncFuncExpr ;
3031use datafusion_physical_expr:: equivalence:: ProjectionMapping ;
3132use datafusion_physical_expr:: expressions:: Column ;
3233use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
34+ use futures:: Stream ;
3335use futures:: stream:: StreamExt ;
3436use log:: trace;
3537use std:: any:: Any ;
38+ use std:: pin:: Pin ;
3639use std:: sync:: Arc ;
40+ use std:: task:: { Context , Poll , ready} ;
3741
3842/// This structure evaluates a set of async expressions on a record
3943/// batch producing a new record batch
@@ -188,7 +192,16 @@ impl ExecutionPlan for AsyncFuncExec {
188192 let schema_captured = self . schema ( ) ;
189193 let config_options_ref = Arc :: clone ( context. session_config ( ) . options ( ) ) ;
190194
191- let stream_with_async_functions = input_stream. then ( move |batch| {
195+ let coalesced_input_stream = CoalesceInputStream {
196+ input_stream,
197+ batch_coalescer : LimitedBatchCoalescer :: new (
198+ Arc :: clone ( & self . input . schema ( ) ) ,
199+ config_options_ref. execution . batch_size ,
200+ None ,
201+ ) ,
202+ } ;
203+
204+ let stream_with_async_functions = coalesced_input_stream. then ( move |batch| {
192205 // need to clone *again* to capture the async_exprs and schema in the
193206 // stream and satisfy lifetime requirements.
194207 let async_exprs_captured = Arc :: clone ( & async_exprs_captured) ;
@@ -221,6 +234,49 @@ impl ExecutionPlan for AsyncFuncExec {
221234 }
222235}
223236
237+ struct CoalesceInputStream {
238+ input_stream : Pin < Box < dyn RecordBatchStream + Send > > ,
239+ batch_coalescer : LimitedBatchCoalescer ,
240+ }
241+
242+ impl Stream for CoalesceInputStream {
243+ type Item = Result < RecordBatch > ;
244+
245+ fn poll_next (
246+ mut self : Pin < & mut Self > ,
247+ cx : & mut Context < ' _ > ,
248+ ) -> Poll < Option < Self :: Item > > {
249+ let mut completed = false ;
250+
251+ loop {
252+ if let Some ( batch) = self . batch_coalescer . next_completed_batch ( ) {
253+ return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
254+ }
255+
256+ if completed {
257+ return Poll :: Ready ( None ) ;
258+ }
259+
260+ match ready ! ( self . input_stream. poll_next_unpin( cx) ) {
261+ Some ( Ok ( batch) ) => {
262+ if let Err ( err) = self . batch_coalescer . push_batch ( batch) {
263+ return Poll :: Ready ( Some ( Err ( err) ) ) ;
264+ }
265+ }
266+ Some ( err) => {
267+ return Poll :: Ready ( Some ( err) ) ;
268+ }
269+ None => {
270+ completed = true ;
271+ if let Err ( err) = self . batch_coalescer . finish ( ) {
272+ return Poll :: Ready ( Some ( Err ( err) ) ) ;
273+ }
274+ }
275+ }
276+ }
277+ }
278+ }
279+
224280const ASYNC_FN_PREFIX : & str = "__async_fn_" ;
225281
226282/// Maps async_expressions to new columns
@@ -307,3 +363,51 @@ impl AsyncMapper {
307363 Arc :: new ( Column :: new ( async_expr. name ( ) , output_idx) )
308364 }
309365}
366+
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_execution::{TaskContext, config::SessionConfig};
    use futures::StreamExt;

    use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

    /// Feeds 50 input batches of 6 rows each (300 rows in total) through an
    /// `AsyncFuncExec` configured with a batch size of 200 and checks that
    /// the output is re-chunked into one 200-row batch followed by the
    /// 100-row remainder, after which the stream ends.
    #[tokio::test]
    async fn test_async_fn_with_coalescing() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))],
        )?;

        let batches: Vec<RecordBatch> = vec![batch; 50];

        let session_config = SessionConfig::new().with_batch_size(200);
        let task_ctx =
            Arc::new(TaskContext::default().with_session_config(session_config));

        let test_exec =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        let exec = AsyncFuncExec::try_new(vec![], test_exec)?;

        let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;

        // One full batch, then the flushed remainder.
        for expected_rows in [200, 100] {
            let batch = stream
                .next()
                .await
                .expect("expected to get a record batch")?;
            assert_eq!(expected_rows, batch.num_rows());
        }

        // Polling again after the remainder was flushed must end the stream
        // rather than produce further batches.
        assert!(
            stream.next().await.is_none(),
            "expected the stream to end after the remainder batch"
        );

        Ok(())
    }
}
0 commit comments