1515// specific language governing permissions and limitations 
1616// under the License. 
1717
18+ use  std:: sync:: Arc ; 
19+ 
1820use  crate :: utils:: wait_for_future; 
1921use  datafusion:: arrow:: pyarrow:: ToPyArrow ; 
2022use  datafusion:: arrow:: record_batch:: RecordBatch ; 
2123use  datafusion:: physical_plan:: SendableRecordBatchStream ; 
2224use  futures:: StreamExt ; 
25+ use  pyo3:: exceptions:: { PyStopAsyncIteration ,  PyStopIteration } ; 
2326use  pyo3:: prelude:: * ; 
2427use  pyo3:: { pyclass,  pymethods,  PyObject ,  PyResult ,  Python } ; 
28+ use  tokio:: sync:: Mutex ; 
2529
2630#[ pyclass( name = "RecordBatch" ,  module = "datafusion" ,  subclass) ]  
2731pub  struct  PyRecordBatch  { 
@@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch {
4347
4448#[ pyclass( name = "RecordBatchStream" ,  module = "datafusion" ,  subclass) ]  
4549pub  struct  PyRecordBatchStream  { 
46-     stream :  SendableRecordBatchStream , 
50+     stream :  Arc < Mutex < SendableRecordBatchStream > > , 
4751} 
4852
4953impl  PyRecordBatchStream  { 
5054    pub  fn  new ( stream :  SendableRecordBatchStream )  -> Self  { 
51-         Self  {  stream } 
55+         Self  { 
56+             stream :  Arc :: new ( Mutex :: new ( stream) ) , 
57+         } 
5258    } 
5359} 
5460
5561#[ pymethods]  
5662impl  PyRecordBatchStream  { 
57-     fn  next ( & mut  self ,  py :  Python )  -> PyResult < Option < PyRecordBatch > >  { 
58-         let  result = self . stream . next ( ) ; 
59-         match  wait_for_future ( py,  result)  { 
60-             None  => Ok ( None ) , 
61-             Some ( Ok ( b) )  => Ok ( Some ( b. into ( ) ) ) , 
62-             Some ( Err ( e) )  => Err ( e. into ( ) ) , 
63-         } 
63+     fn  next ( & mut  self ,  py :  Python )  -> PyResult < PyRecordBatch >  { 
64+         let  stream = self . stream . clone ( ) ; 
65+         wait_for_future ( py,  next_stream ( stream,  true ) ) 
6466    } 
6567
66-     fn  __next__ ( & mut  self ,  py :  Python )  -> PyResult < Option < PyRecordBatch > >  { 
68+     fn  __next__ ( & mut  self ,  py :  Python )  -> PyResult < PyRecordBatch >  { 
6769        self . next ( py) 
6870    } 
6971
72+     fn  __anext__ < ' py > ( & ' py  self ,  py :  Python < ' py > )  -> PyResult < Bound < ' py ,  PyAny > >  { 
73+         let  stream = self . stream . clone ( ) ; 
74+         pyo3_async_runtimes:: tokio:: future_into_py ( py,  next_stream ( stream,  false ) ) 
75+     } 
76+ 
7077    fn  __iter__ ( slf :  PyRef < ' _ ,  Self > )  -> PyRef < ' _ ,  Self >  { 
7178        slf
7279    } 
80+ 
81+     fn  __aiter__ ( slf :  PyRef < ' _ ,  Self > )  -> PyRef < ' _ ,  Self >  { 
82+         slf
83+     } 
84+ } 
85+ 
86+ async  fn  next_stream ( 
87+     stream :  Arc < Mutex < SendableRecordBatchStream > > , 
88+     sync :  bool , 
89+ )  -> PyResult < PyRecordBatch >  { 
90+     let  mut  stream = stream. lock ( ) . await ; 
91+     match  stream. next ( ) . await  { 
92+         Some ( Ok ( batch) )  => Ok ( batch. into ( ) ) , 
93+         Some ( Err ( e) )  => Err ( e. into ( ) ) , 
94+         None  => { 
95+             // Depending on whether the iteration is sync or not, we raise either a 
96+             // StopIteration or a StopAsyncIteration 
97+             if  sync { 
98+                 Err ( PyStopIteration :: new_err ( "stream exhausted" ) ) 
99+             }  else  { 
100+                 Err ( PyStopAsyncIteration :: new_err ( "stream exhausted" ) ) 
101+             } 
102+         } 
103+     } 
73104} 
0 commit comments