1717
1818from datafusion_ray .context import DatafusionRayContext
1919from datafusion import SessionContext , SessionConfig , RuntimeConfig , col , lit , functions as F
20+ import pytest
2021
21-
22- def test_basic_query_succeed ():
@pytest.fixture
def df_ctx():
    """Provide a DataFusion ``SessionContext`` configured with 4 target partitions."""
    # Pin the partition count so the tests behave the same way in every
    # environment.
    return SessionContext(config=SessionConfig().with_target_partitions(4))
28+
@pytest.fixture
def ctx(df_ctx):
    """Provide a ``DatafusionRayContext`` wrapping the ``df_ctx`` fixture."""
    ray_context = DatafusionRayContext(df_ctx)
    return ray_context
33+
def test_basic_query_succeed(df_ctx, ctx):
    """Run ``SELECT *`` over the tips CSV and check batch and row counts."""
    df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)

    batches = ctx.sql("SELECT * FROM tips")
    # Presumably bounded by the 4 target partitions set in the df_ctx fixture.
    assert len(batches) <= 4

    total_rows = 0
    for batch in batches:
        total_rows += batch.num_rows
    assert total_rows == 244
3240
33- def test_aggregate_csv ():
34- config = SessionConfig ().with_target_partitions (4 )
35- df_ctx = SessionContext (config = config )
36- ctx = DatafusionRayContext (df_ctx )
def test_aggregate_csv(df_ctx, ctx):
    """Run a grouped average over the tips CSV and expect four result rows."""
    df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)

    query = (
        "select sex, smoker, avg(tip/total_bill) as tip_pct "
        "from tips group by sex, smoker"
    )
    batches = ctx.sql(query)
    # Presumably bounded by the 4 target partitions set in the df_ctx fixture.
    assert len(batches) <= 4

    total_rows = 0
    for batch in batches:
        total_rows += batch.num_rows
    assert total_rows == 4
4247
43- def test_aggregate_parquet ():
44- config = SessionConfig ().with_target_partitions (4 )
45- df_ctx = SessionContext (config = config )
46- ctx = DatafusionRayContext (df_ctx )
def test_aggregate_parquet(df_ctx, ctx):
    """Run a grouped average over the tips Parquet file and expect four result rows."""
    df_ctx.register_parquet("tips", "examples/tips.parquet")

    query = (
        "select sex, smoker, avg(tip/total_bill) as tip_pct "
        "from tips group by sex, smoker"
    )
    batches = ctx.sql(query)
    # Presumably bounded by the 4 target partitions set in the df_ctx fixture.
    assert len(batches) <= 4

    total_rows = 0
    for batch in batches:
        total_rows += batch.num_rows
    assert total_rows == 4
5254
53- def test_aggregate_parquet_dataframe ():
54- config = SessionConfig ().with_target_partitions (4 )
55- df_ctx = SessionContext (config = config )
56- ray_ctx = DatafusionRayContext (df_ctx )
55+ def test_aggregate_parquet_dataframe (df_ctx , ctx ):
5756 df = df_ctx .read_parquet (f"examples/tips.parquet" )
5857 df = (
5958 df .aggregate (
@@ -63,13 +62,10 @@ def test_aggregate_parquet_dataframe():
6362 .filter (col ("day" ) != lit ("Dinner" ))
6463 .aggregate ([col ("sex" ), col ("smoker" )], [F .avg (col ("tip_pct" )).alias ("avg_pct" )])
6564 )
66- ray_results = ray_ctx .plan (df .execution_plan ())
65+ ray_results = ctx .plan (df .execution_plan ())
6766 df_ctx .create_dataframe ([ray_results ]).show ()
6867
6968
70- def test_no_result_query ():
71- config = SessionConfig ().with_target_partitions (4 )
72- df_ctx = SessionContext (config = config )
73- ctx = DatafusionRayContext (df_ctx )
def test_no_result_query(df_ctx, ctx):
    """Run a ``CREATE VIEW`` statement through the Ray context.

    The test makes no assertions; it passes as long as the statement runs
    without raising.
    """
    df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
    create_view_stmt = "CREATE VIEW tips_view AS SELECT * FROM tips"
    ctx.sql(create_view_stmt)
0 commit comments