Skip to content

Commit ebe403a

Browse files
committed
tests
1 parent 4e60563 commit ebe403a

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/query_stage.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::context::serialize_execution_plan;
19-
use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
19+
use crate::shuffle::{ShuffleCodec, ShuffleReaderExec, ShuffleWriterExec};
2020
use datafusion::error::Result;
2121
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, Partitioning};
2222
use datafusion::prelude::SessionContext;
@@ -99,7 +99,14 @@ impl QueryStage {
9999
/// Get the input partition count. This is the same as the number of concurrent tasks
100100
/// when we schedule this query stage for execution.
101101
pub fn get_input_partition_count(&self) -> usize {
102-
self.plan.output_partitioning().partition_count()
102+
self.plan.children()[0].output_partitioning().partition_count()
103+
if self.plan.as_any().is::<ShuffleWriterExec>() {
104+
// most query stages represent a shuffle write
105+
self.plan.children()[0].output_partitioning().partition_count()
106+
} else {
107+
// probably the final query stage
108+
self.plan.output_partitioning().partition_count()
109+
}
103110
}
104111

105112
pub fn get_output_partition_count(&self) -> usize {

tests/test_context.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,21 @@ def test_basic_query_succeed():
2323
df_ctx = SessionContext()
2424
ctx = DatafusionRayContext(df_ctx)
2525
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
26+
# TODO: why does this return a single batch rather than a list of batches?
2627
record_batch = ctx.sql("SELECT * FROM tips")
2728
assert record_batch.num_rows == 244
2829

30+
def test_aggregate():
31+
df_ctx = SessionContext()
32+
ctx = DatafusionRayContext(df_ctx)
33+
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
34+
record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker")
35+
assert isinstance(record_batches, list)
36+
# TODO: why does this return many empty batches?
37+
num_rows = 0
38+
for record_batch in record_batches:
39+
num_rows += record_batch.num_rows
40+
assert num_rows == 4
2941

3042
def test_no_result_query():
3143
df_ctx = SessionContext()

0 commit comments

Comments
 (0)