Skip to content

Commit 4940506

Browse files
committed
use test fixtures
1 parent 912c527 commit 4940506

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

datafusion_ray/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,5 @@ def plan(self, execution_plan: Any) -> List[pa.RecordBatch]:
160160
_, partitions = ray.get(future)
161161
# assert len(partitions) == 1, len(partitions)
162162
record_batches = ray.get(partitions[0])
163+
# filter out empty batches
163164
return [batch for batch in record_batches if batch.num_rows > 0]

tests/test_context.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,42 @@
1717

1818
from datafusion_ray.context import DatafusionRayContext
1919
from datafusion import SessionContext, SessionConfig, RuntimeConfig, col, lit, functions as F
20+
import pytest
2021

21-
22-
def test_basic_query_succeed():
22+
@pytest.fixture
23+
def df_ctx():
24+
"""Fixture to create a DataFusion context."""
25+
# used fixed partition count so that tests are deterministic on different environments
2326
config = SessionConfig().with_target_partitions(4)
24-
df_ctx = SessionContext(config=config)
25-
ctx = DatafusionRayContext(df_ctx)
27+
return SessionContext(config=config)
28+
29+
@pytest.fixture
30+
def ctx(df_ctx):
31+
"""Fixture to create a Datafusion Ray context."""
32+
return DatafusionRayContext(df_ctx)
33+
34+
def test_basic_query_succeed(df_ctx, ctx):
2635
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
27-
# TODO why does this return a single batch and not a list of batches?
2836
record_batches = ctx.sql("SELECT * FROM tips")
2937
assert len(record_batches) <= 4
3038
num_rows = sum(batch.num_rows for batch in record_batches)
3139
assert num_rows == 244
3240

33-
def test_aggregate_csv():
34-
config = SessionConfig().with_target_partitions(4)
35-
df_ctx = SessionContext(config=config)
36-
ctx = DatafusionRayContext(df_ctx)
41+
def test_aggregate_csv(df_ctx, ctx):
3742
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
3843
record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker")
3944
assert len(record_batches) <= 4
4045
num_rows = sum(batch.num_rows for batch in record_batches)
4146
assert num_rows == 4
4247

43-
def test_aggregate_parquet():
44-
config = SessionConfig().with_target_partitions(4)
45-
df_ctx = SessionContext(config=config)
46-
ctx = DatafusionRayContext(df_ctx)
48+
def test_aggregate_parquet(df_ctx, ctx):
4749
df_ctx.register_parquet("tips", "examples/tips.parquet")
4850
record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker")
4951
assert len(record_batches) <= 4
5052
num_rows = sum(batch.num_rows for batch in record_batches)
5153
assert num_rows == 4
5254

53-
def test_aggregate_parquet_dataframe():
54-
config = SessionConfig().with_target_partitions(4)
55-
df_ctx = SessionContext(config=config)
56-
ray_ctx = DatafusionRayContext(df_ctx)
55+
def test_aggregate_parquet_dataframe(df_ctx, ctx):
5756
df = df_ctx.read_parquet(f"examples/tips.parquet")
5857
df = (
5958
df.aggregate(
@@ -63,13 +62,10 @@ def test_aggregate_parquet_dataframe():
6362
.filter(col("day") != lit("Dinner"))
6463
.aggregate([col("sex"), col("smoker")], [F.avg(col("tip_pct")).alias("avg_pct")])
6564
)
66-
ray_results = ray_ctx.plan(df.execution_plan())
65+
ray_results = ctx.plan(df.execution_plan())
6766
df_ctx.create_dataframe([ray_results]).show()
6867

6968

70-
def test_no_result_query():
71-
config = SessionConfig().with_target_partitions(4)
72-
df_ctx = SessionContext(config=config)
73-
ctx = DatafusionRayContext(df_ctx)
69+
def test_no_result_query(df_ctx, ctx):
7470
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
7571
ctx.sql("CREATE VIEW tips_view AS SELECT * FROM tips")

0 commit comments

Comments
 (0)