1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use std:: sync:: Arc ;
19+
1820use crate :: utils:: wait_for_future;
1921use datafusion:: arrow:: pyarrow:: ToPyArrow ;
2022use datafusion:: arrow:: record_batch:: RecordBatch ;
2123use datafusion:: physical_plan:: SendableRecordBatchStream ;
2224use futures:: StreamExt ;
25+ use pyo3:: exceptions:: { PyStopAsyncIteration , PyStopIteration } ;
2326use pyo3:: prelude:: * ;
2427use pyo3:: { pyclass, pymethods, PyObject , PyResult , Python } ;
28+ use tokio:: sync:: Mutex ;
2529
2630#[ pyclass( name = "RecordBatch" , module = "datafusion" , subclass) ]
2731pub struct PyRecordBatch {
@@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch {
4347
4448#[ pyclass( name = "RecordBatchStream" , module = "datafusion" , subclass) ]
4549pub struct PyRecordBatchStream {
46- stream : SendableRecordBatchStream ,
50+ stream : Arc < Mutex < SendableRecordBatchStream > > ,
4751}
4852
4953impl PyRecordBatchStream {
5054 pub fn new ( stream : SendableRecordBatchStream ) -> Self {
51- Self { stream }
55+ Self {
56+ stream : Arc :: new ( Mutex :: new ( stream) ) ,
57+ }
5258 }
5359}
5460
5561#[ pymethods]
5662impl PyRecordBatchStream {
57- fn next ( & mut self , py : Python ) -> PyResult < Option < PyRecordBatch > > {
58- let result = self . stream . next ( ) ;
59- match wait_for_future ( py, result) {
60- None => Ok ( None ) ,
61- Some ( Ok ( b) ) => Ok ( Some ( b. into ( ) ) ) ,
62- Some ( Err ( e) ) => Err ( e. into ( ) ) ,
63- }
63+ fn next ( & mut self , py : Python ) -> PyResult < PyRecordBatch > {
64+ let stream = self . stream . clone ( ) ;
65+ wait_for_future ( py, next_stream ( stream, true ) )
6466 }
6567
66- fn __next__ ( & mut self , py : Python ) -> PyResult < Option < PyRecordBatch > > {
68+ fn __next__ ( & mut self , py : Python ) -> PyResult < PyRecordBatch > {
6769 self . next ( py)
6870 }
6971
72+ fn __anext__ < ' py > ( & ' py self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
73+ let stream = self . stream . clone ( ) ;
74+ pyo3_async_runtimes:: tokio:: future_into_py ( py, next_stream ( stream, false ) )
75+ }
76+
7077 fn __iter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
7178 slf
7279 }
80+
81+ fn __aiter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
82+ slf
83+ }
84+ }
85+
86+ async fn next_stream (
87+ stream : Arc < Mutex < SendableRecordBatchStream > > ,
88+ sync : bool ,
89+ ) -> PyResult < PyRecordBatch > {
90+ let mut stream = stream. lock ( ) . await ;
91+ match stream. next ( ) . await {
92+ Some ( Ok ( batch) ) => Ok ( batch. into ( ) ) ,
93+ Some ( Err ( e) ) => Err ( e. into ( ) ) ,
94+ None => {
95+ // Depending on whether the iteration is sync or not, we raise either a
96+ // StopIteration or a StopAsyncIteration
97+ if sync {
98+ Err ( PyStopIteration :: new_err ( "stream exhausted" ) )
99+ } else {
100+ Err ( PyStopAsyncIteration :: new_err ( "stream exhausted" ) )
101+ }
102+ }
103+ }
73104}
0 commit comments