1616# under the License.
1717
1818from datafusion_ray .context import DatafusionRayContext
19- from datafusion import SessionContext
19+ from datafusion import SessionContext , SessionConfig , RuntimeConfig , col , lit , functions as F
2020
2121
2222def test_basic_query_succeed ():
@@ -27,7 +27,7 @@ def test_basic_query_succeed():
2727 record_batch = ctx .sql ("SELECT * FROM tips" )
2828 assert record_batch .num_rows == 244
2929
30- def test_aggregate ():
30+ def test_aggregate_csv ():
3131 df_ctx = SessionContext ()
3232 ctx = DatafusionRayContext (df_ctx )
3333 df_ctx .register_csv ("tips" , "examples/tips.csv" , has_header = True )
@@ -39,6 +39,36 @@ def test_aggregate():
3939 num_rows += record_batch .num_rows
4040 assert num_rows == 4
4141
42+ def test_aggregate_parquet ():
43+ runtime = RuntimeConfig ()
44+ config = SessionConfig ().set ('datafusion.execution.parquet.schema_force_view_types' , 'true' )
45+ df_ctx = SessionContext (config , runtime )
46+ ctx = DatafusionRayContext (df_ctx )
47+ df_ctx .register_parquet ("tips" , "examples/tips.parquet" )
48+ record_batches = ctx .sql ("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker" )
49+ assert isinstance (record_batches , list )
50+ # TODO why does this return many empty batches?
51+ num_rows = 0
52+ for record_batch in record_batches :
53+ num_rows += record_batch .num_rows
54+ assert num_rows == 4
55+
56+ def test_aggregate_parquet_dataframe ():
57+ df_ctx = SessionContext ()
58+ ray_ctx = DatafusionRayContext (df_ctx )
59+ df = df_ctx .read_parquet (f"examples/tips.parquet" )
60+ df = (
61+ df .aggregate (
62+ [col ("sex" ), col ("smoker" ), col ("day" ), col ("time" )],
63+ [F .avg (col ("tip" ) / col ("total_bill" )).alias ("tip_pct" )],
64+ )
65+ .filter (col ("day" ) != lit ("Dinner" ))
66+ .aggregate ([col ("sex" ), col ("smoker" )], [F .avg (col ("tip_pct" )).alias ("avg_pct" )])
67+ )
68+ ray_results = ray_ctx .plan (df .execution_plan ())
69+ df_ctx .create_dataframe ([ray_results ]).show ()
70+
71+
4272def test_no_result_query ():
4373 df_ctx = SessionContext ()
4474 ctx = DatafusionRayContext (df_ctx )
0 commit comments