1 change: 1 addition & 0 deletions pyproject.toml
@@ -150,6 +150,7 @@ dev = [
    "maturin>=1.8.1",
    "numpy>1.25.0",
    "pytest>=7.4.4",
    "pytest-asyncio>=0.23.3",
    "ruff>=0.9.1",
    "toml>=0.10.2",
    "pygithub==2.5.0",
54 changes: 54 additions & 0 deletions python/tests/test_dataframe.py
@@ -771,6 +771,16 @@ def test_execution_plan(aggregate_df):
    assert rows_returned == 5


@pytest.mark.asyncio
async def test_async_iteration_of_df(aggregate_df):
    rows_returned = 0
    async for batch in aggregate_df.execute_stream():
        assert batch is not None
        rows_returned += len(batch.to_pyarrow()[0])

    assert rows_returned == 5
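Outside pytest, the same `async for` works in any coroutine. A minimal sketch, assuming a `SessionContext` and a hypothetical `data.csv` to read; only `execute_stream()` and `to_pyarrow()` are taken from the test above:

import asyncio

from datafusion import SessionContext

async def count_rows() -> int:
    ctx = SessionContext()
    df = ctx.read_csv("data.csv")  # hypothetical input file
    total = 0
    async for batch in df.execute_stream():
        # each yielded item converts to a pyarrow RecordBatch
        total += batch.to_pyarrow().num_rows
    return total

if __name__ == "__main__":
    print(asyncio.run(count_rows()))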


def test_repartition(df):
    df.repartition(2)

@@ -958,6 +968,18 @@ def test_execute_stream(df):
    assert not list(stream)  # after one iteration the generator must be exhausted


@pytest.mark.asyncio
async def test_execute_stream_async(df):
    stream = df.execute_stream()
    batches = [batch async for batch in stream]

    assert all(batch is not None for batch in batches)

    # After consuming all batches, the stream should be exhausted
    remaining_batches = [batch async for batch in stream]
    assert not remaining_batches
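The empty second comprehension implies that `__aiter__` hands back the already-exhausted stream object itself. Under that assumption, the same check can be written against the async-iterator protocol directly (a sketch, not part of the PR):

@pytest.mark.asyncio
async def test_exhausted_stream_stops(df):
    stream = df.execute_stream()
    _ = [batch async for batch in stream]  # drain the stream
    # a drained async iterator signals the end by raising StopAsyncIteration
    with pytest.raises(StopAsyncIteration):
        await stream.__anext__()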


@pytest.mark.parametrize("schema", [True, False])
def test_execute_stream_to_arrow_table(df, schema):
    stream = df.execute_stream()
@@ -974,6 +996,25 @@ def test_execute_stream_to_arrow_table(df, schema):
    assert set(pyarrow_table.column_names) == {"a", "b", "c"}


@pytest.mark.asyncio
@pytest.mark.parametrize("schema", [True, False])
async def test_execute_stream_to_arrow_table_async(df, schema):
    stream = df.execute_stream()

    if schema:
        pyarrow_table = pa.Table.from_batches(
            [batch.to_pyarrow() async for batch in stream], schema=df.schema()
        )
    else:
        pyarrow_table = pa.Table.from_batches(
            [batch.to_pyarrow() async for batch in stream]
        )

    assert isinstance(pyarrow_table, pa.Table)
    assert pyarrow_table.shape == (3, 3)
    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
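Passing `schema=` is more than cosmetic: `pa.Table.from_batches` cannot infer a schema from an empty batch list, so the explicit-schema form also covers streams that yield no batches. A standalone illustration of that pyarrow behavior (the schema here is a made-up example):

import pyarrow as pa

schema = pa.schema([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())])
# with an explicit schema, an empty batch list still builds a valid zero-row Table
empty = pa.Table.from_batches([], schema=schema)
assert empty.num_rows == 0
assert empty.schema == schema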


def test_execute_stream_partitioned(df):
    streams = df.execute_stream_partitioned()
    assert all(batch is not None for stream in streams for batch in stream)
Expand All @@ -982,6 +1023,19 @@ def test_execute_stream_partitioned(df):
    )  # after one iteration all generators must be exhausted


@pytest.mark.asyncio
async def test_execute_stream_partitioned_async(df):
    streams = df.execute_stream_partitioned()

    for stream in streams:
        batches = [batch async for batch in stream]
        assert all(batch is not None for batch in batches)

        # Ensure the stream is exhausted after iteration
        remaining_batches = [batch async for batch in stream]
        assert not remaining_batches
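Because each partition is its own async stream, a caller can also drain them concurrently rather than one after the other as the test does. A sketch under the assumption that the per-partition streams may be awaited side by side on one event loop:

import asyncio

async def drain(stream) -> int:
    total = 0
    async for batch in stream:
        total += batch.to_pyarrow().num_rows
    return total

async def count_all_partitions(df) -> int:
    # one task per partition; gather drives all the streams concurrently
    counts = await asyncio.gather(
        *(drain(stream) for stream in df.execute_stream_partitioned())
    )
    return sum(counts)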


def test_empty_to_arrow_table(df):
    # Convert empty datafusion dataframe to pyarrow Table
    pyarrow_table = df.limit(0).to_arrow_table()
17 changes: 16 additions & 1 deletion uv.lock

Some generated files are not rendered by default.