From a87e69d9172e086fd36e17a1b9a6d51bf6a98f5e Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Fri, 14 Mar 2025 09:07:32 -0600 Subject: [PATCH] added pytest asyncio tests --- pyproject.toml | 1 + python/tests/test_dataframe.py | 54 ++++++++++++++++++++++++++++++++++ uv.lock | 17 ++++++++++- 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 060e3b80a..a4ed18c4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,7 @@ dev = [ "maturin>=1.8.1", "numpy>1.25.0", "pytest>=7.4.4", + "pytest-asyncio>=0.23.3", "ruff>=0.9.1", "toml>=0.10.2", "pygithub==2.5.0", diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index d084f12dd..384b17878 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -771,6 +771,16 @@ def test_execution_plan(aggregate_df): assert rows_returned == 5 +@pytest.mark.asyncio +async def test_async_iteration_of_df(aggregate_df): + rows_returned = 0 + async for batch in aggregate_df.execute_stream(): + assert batch is not None + rows_returned += len(batch.to_pyarrow()[0]) + + assert rows_returned == 5 + + def test_repartition(df): df.repartition(2) @@ -958,6 +968,18 @@ def test_execute_stream(df): assert not list(stream) # after one iteration the generator must be exhausted +@pytest.mark.asyncio +async def test_execute_stream_async(df): + stream = df.execute_stream() + batches = [batch async for batch in stream] + + assert all(batch is not None for batch in batches) + + # After consuming all batches, the stream should be exhausted + remaining_batches = [batch async for batch in stream] + assert not remaining_batches + + @pytest.mark.parametrize("schema", [True, False]) def test_execute_stream_to_arrow_table(df, schema): stream = df.execute_stream() @@ -974,6 +996,25 @@ def test_execute_stream_to_arrow_table(df, schema): assert set(pyarrow_table.column_names) == {"a", "b", "c"} +@pytest.mark.asyncio +@pytest.mark.parametrize("schema", [True, False]) +async def test_execute_stream_to_arrow_table_async(df, schema): + stream = df.execute_stream() + + if schema: + pyarrow_table = pa.Table.from_batches( + [batch.to_pyarrow() async for batch in stream], schema=df.schema() + ) + else: + pyarrow_table = pa.Table.from_batches( + [batch.to_pyarrow() async for batch in stream] + ) + + assert isinstance(pyarrow_table, pa.Table) + assert pyarrow_table.shape == (3, 3) + assert set(pyarrow_table.column_names) == {"a", "b", "c"} + + def test_execute_stream_partitioned(df): streams = df.execute_stream_partitioned() assert all(batch is not None for stream in streams for batch in stream) @@ -982,6 +1023,19 @@ def test_execute_stream_partitioned(df): ) # after one iteration all generators must be exhausted +@pytest.mark.asyncio +async def test_execute_stream_partitioned_async(df): + streams = df.execute_stream_partitioned() + + for stream in streams: + batches = [batch async for batch in stream] + assert all(batch is not None for batch in batches) + + # Ensure the stream is exhausted after iteration + remaining_batches = [batch async for batch in stream] + assert not remaining_batches + + def test_empty_to_arrow_table(df): # Convert empty datafusion dataframe to pyarrow Table pyarrow_table = df.limit(0).to_arrow_table() diff --git a/uv.lock b/uv.lock index 619b92856..7e4bc4c6b 100644 --- a/uv.lock +++ b/uv.lock @@ -284,9 +284,11 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "maturin" }, + { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "numpy", version = "2.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "pygithub" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ruff" }, { name = "toml" }, ] @@ -314,9 +316,10 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "maturin", specifier = ">=1.8.1" }, - { name = "numpy", marker = "python_full_version >= '3.10'", specifier = ">1.24.4" }, + { name = "numpy", specifier = ">1.25.0" }, { name = "pygithub", specifier = "==2.5.0" }, { name = "pytest", specifier = ">=7.4.4" }, + { name = "pytest-asyncio", specifier = ">=0.23.3" }, { name = "ruff", specifier = ">=0.9.1" }, { name = "toml", specifier = ">=0.10.2" }, ] @@ -1079,6 +1082,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.25.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"