@@ -771,6 +771,16 @@ def test_execution_plan(aggregate_df):
771771 assert rows_returned == 5
772772
773773
@pytest.mark.asyncio
async def test_async_iteration_of_df(aggregate_df):
    """The DataFrame's record-batch stream must be consumable with ``async for``."""
    total_rows = 0
    async for batch in aggregate_df.execute_stream():
        assert batch is not None
        # Row count of the batch == length of its first pyarrow column.
        total_rows += len(batch.to_pyarrow()[0])

    assert total_rows == 5
782+
783+
774784def test_repartition (df ):
775785 df .repartition (2 )
776786
@@ -958,6 +968,18 @@ def test_execute_stream(df):
958968 assert not list (stream ) # after one iteration the generator must be exhausted
959969
960970
@pytest.mark.asyncio
async def test_execute_stream_async(df):
    """Asynchronously drain the batch stream and verify it is then exhausted."""
    stream = df.execute_stream()

    first_pass = [b async for b in stream]
    assert all(b is not None for b in first_pass)

    # A fully consumed stream must yield nothing on a second pass.
    second_pass = [b async for b in stream]
    assert not second_pass
981+
982+
961983@pytest .mark .parametrize ("schema" , [True , False ])
962984def test_execute_stream_to_arrow_table (df , schema ):
963985 stream = df .execute_stream ()
@@ -974,6 +996,25 @@ def test_execute_stream_to_arrow_table(df, schema):
974996 assert set (pyarrow_table .column_names ) == {"a" , "b" , "c" }
975997
976998
@pytest.mark.asyncio
@pytest.mark.parametrize("schema", [True, False])
async def test_execute_stream_to_arrow_table_async(df, schema):
    """Collect the async stream into a pyarrow Table, with and without an explicit schema."""
    batches = [batch.to_pyarrow() async for batch in df.execute_stream()]

    # ``schema=None`` is the from_batches default: pyarrow infers the schema
    # from the first batch, exactly as the no-schema branch would.
    pyarrow_table = pa.Table.from_batches(
        batches, schema=df.schema() if schema else None
    )

    assert isinstance(pyarrow_table, pa.Table)
    assert pyarrow_table.shape == (3, 3)
    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
1016+
1017+
9771018def test_execute_stream_partitioned (df ):
9781019 streams = df .execute_stream_partitioned ()
9791020 assert all (batch is not None for stream in streams for batch in stream )
@@ -982,6 +1023,19 @@ def test_execute_stream_partitioned(df):
9821023 ) # after one iteration all generators must be exhausted
9831024
9841025
@pytest.mark.asyncio
async def test_execute_stream_partitioned_async(df):
    """Every partition stream is async-iterable and exhausted after iteration."""
    for stream in df.execute_stream_partitioned():
        collected = [b async for b in stream]
        assert all(b is not None for b in collected)

        # Re-iterating an exhausted partition stream must produce no batches.
        assert not [b async for b in stream]
1037+
1038+
9851039def test_empty_to_arrow_table (df ):
9861040 # Convert empty datafusion dataframe to pyarrow Table
9871041 pyarrow_table = df .limit (0 ).to_arrow_table ()
0 commit comments