@@ -31,15 +31,19 @@ use datafusion::physical_plan::repartition::RepartitionExec;
3131use datafusion:: physical_plan:: sorts:: sort:: SortExec ;
3232use datafusion:: physical_plan:: { ExecutionPlan , ExecutionPlanProperties } ;
3333use datafusion:: prelude:: DataFrame ;
34+ use datafusion_python:: errors:: PyDataFusionError ;
3435use datafusion_python:: physical_plan:: PyExecutionPlan ;
3536use datafusion_python:: sql:: logical:: PyLogicalPlan ;
3637use datafusion_python:: utils:: wait_for_future;
3738use futures:: stream:: StreamExt ;
3839use itertools:: Itertools ;
3940use log:: trace;
41+ use pyo3:: exceptions:: PyStopAsyncIteration ;
42+ use pyo3:: exceptions:: PyStopIteration ;
4043use pyo3:: prelude:: * ;
4144use std:: borrow:: Cow ;
4245use std:: sync:: Arc ;
46+ use tokio:: sync:: Mutex ;
4347
4448use crate :: isolator:: PartitionIsolatorExec ;
4549use crate :: max_rows:: MaxRowsExec ;
@@ -428,9 +432,12 @@ impl PyDataFrameStage {
428432 }
429433}
430434
435+ // PyRecordBatch and PyRecordBatchStream are borrowed, and slightly modified from datafusion-python
436+ // they are not publicly exposed in that repo
437+
431438#[ pyclass]
432439pub struct PyRecordBatch {
433- batch : RecordBatch ,
440+ pub batch : RecordBatch ,
434441}
435442
436443#[ pymethods]
@@ -448,31 +455,58 @@ impl From<RecordBatch> for PyRecordBatch {
448455
449456#[ pyclass]
450457pub struct PyRecordBatchStream {
451- stream : SendableRecordBatchStream ,
458+ stream : Arc < Mutex < SendableRecordBatchStream > > ,
452459}
453460
454461impl PyRecordBatchStream {
455462 pub fn new ( stream : SendableRecordBatchStream ) -> Self {
456- Self { stream }
463+ Self {
464+ stream : Arc :: new ( Mutex :: new ( stream) ) ,
465+ }
457466 }
458467}
459468
460469#[ pymethods]
461470impl PyRecordBatchStream {
462- fn next ( & mut self , py : Python ) -> PyResult < Option < PyObject > > {
463- let result = self . stream . next ( ) ;
464- match wait_for_future ( py, result) {
465- None => Ok ( None ) ,
466- Some ( Ok ( b) ) => Ok ( Some ( b. to_pyarrow ( py) ?) ) ,
467- Some ( Err ( e) ) => Err ( e. into ( ) ) ,
468- }
471+ fn next ( & mut self , py : Python ) -> PyResult < PyObject > {
472+ let stream = self . stream . clone ( ) ;
473+ wait_for_future ( py, next_stream ( stream, true ) ) . and_then ( |b| b. to_pyarrow ( py) )
469474 }
470475
471- fn __next__ ( & mut self , py : Python ) -> PyResult < Option < PyObject > > {
476+ fn __next__ ( & mut self , py : Python ) -> PyResult < PyObject > {
472477 self . next ( py)
473478 }
474479
480+ fn __anext__ < ' py > ( & ' py self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
481+ let stream = self . stream . clone ( ) ;
482+ pyo3_async_runtimes:: tokio:: future_into_py ( py, next_stream ( stream, false ) )
483+ }
484+
475485 fn __iter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
476486 slf
477487 }
488+
489+ fn __aiter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
490+ slf
491+ }
492+ }
493+
494+ async fn next_stream (
495+ stream : Arc < Mutex < SendableRecordBatchStream > > ,
496+ sync : bool ,
497+ ) -> PyResult < PyRecordBatch > {
498+ let mut stream = stream. lock ( ) . await ;
499+ match stream. next ( ) . await {
500+ Some ( Ok ( batch) ) => Ok ( batch. into ( ) ) ,
501+ Some ( Err ( e) ) => Err ( PyDataFusionError :: from ( e) ) ?,
502+ None => {
503+ // Depending on whether the iteration is sync or not, we raise either a
504+ // StopIteration or a StopAsyncIteration
505+ if sync {
506+ Err ( PyStopIteration :: new_err ( "stream exhausted" ) )
507+ } else {
508+ Err ( PyStopAsyncIteration :: new_err ( "stream exhausted" ) )
509+ }
510+ }
511+ }
478512}
0 commit comments