Commit 2783613

chore: Make query stage / shuffle code easier to understand (#54)

1 parent 151a0e2 · commit 2783613

23 files changed, +72 -77 lines changed

datafusion_ray/context.py

Lines changed: 4 additions & 3 deletions
@@ -50,7 +50,7 @@ def execute_query_stage(
 
     # if the query stage has a single output partition then we need to execute for the output
     # partition, otherwise we need to execute in parallel for each input partition
-    concurrency = stage.get_input_partition_count()
+    concurrency = stage.get_execution_partition_count()
     output_partitions_count = stage.get_output_partition_count()
     if output_partitions_count == 1:
         # reduce stage
@@ -159,5 +159,6 @@ def plan(self, execution_plan: Any) -> List[pa.RecordBatch]:
         )
         _, partitions = ray.get(future)
         # assert len(partitions) == 1, len(partitions)
-        result_set = ray.get(partitions[0])
-        return result_set
+        record_batches = ray.get(partitions[0])
+        # filter out empty batches
+        return [batch for batch in record_batches if batch.num_rows > 0]
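
The new return value drops batches with no rows before handing results back to the caller. For reference, a minimal sketch of the same filter on the Rust side, assuming a hypothetical drop_empty_batches helper that is not part of this commit:

use datafusion::arrow::record_batch::RecordBatch;

/// Hypothetical helper mirroring the Python list comprehension above:
/// keep only the record batches that actually contain rows.
fn drop_empty_batches(batches: Vec<RecordBatch>) -> Vec<RecordBatch> {
    batches.into_iter().filter(|b| b.num_rows() > 0).collect()
}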

src/planner.rs

Lines changed: 1 addition & 1 deletion
@@ -399,7 +399,7 @@ mod test {
         let query_stage = graph.query_stages.get(&id).unwrap();
         output.push_str(&format!(
             "Query Stage #{id} ({} -> {}):\n{}\n",
-            query_stage.get_input_partition_count(),
+            query_stage.get_execution_partition_count(),
             query_stage.get_output_partition_count(),
             displayable(query_stage.plan.as_ref()).indent(false)
         ));

src/query_stage.rs

Lines changed: 19 additions & 23 deletions
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::context::serialize_execution_plan;
-use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
+use crate::shuffle::{ShuffleCodec, ShuffleReaderExec, ShuffleWriterExec};
 use datafusion::error::Result;
 use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, Partitioning};
 use datafusion::prelude::SessionContext;
@@ -60,8 +60,8 @@ impl PyQueryStage {
         self.stage.get_child_stage_ids()
     }
 
-    pub fn get_input_partition_count(&self) -> usize {
-        self.stage.get_input_partition_count()
+    pub fn get_execution_partition_count(&self) -> usize {
+        self.stage.get_execution_partition_count()
     }
 
     pub fn get_output_partition_count(&self) -> usize {
@@ -75,16 +75,6 @@ pub struct QueryStage {
     pub plan: Arc<dyn ExecutionPlan>,
 }
 
-fn _get_output_partition_count(plan: &dyn ExecutionPlan) -> usize {
-    // UnknownPartitioning and HashPartitioning with empty expressions will
-    // both return 1 partition.
-    match plan.properties().output_partitioning() {
-        Partitioning::UnknownPartitioning(_) => 1,
-        Partitioning::Hash(expr, _) if expr.is_empty() => 1,
-        p => p.partition_count(),
-    }
-}
-
 impl QueryStage {
     pub fn new(id: usize, plan: Arc<dyn ExecutionPlan>) -> Self {
         Self { id, plan }
@@ -96,21 +86,27 @@ impl QueryStage {
         ids
     }
 
-    /// Get the input partition count. This is the same as the number of concurrent tasks
-    /// when we schedule this query stage for execution
-    pub fn get_input_partition_count(&self) -> usize {
-        if self.plan.children().is_empty() {
-            // leaf node (file scan)
-            self.plan.output_partitioning().partition_count()
+    /// Get the number of partitions that can be executed in parallel
+    pub fn get_execution_partition_count(&self) -> usize {
+        if let Some(shuffle) = self.plan.as_any().downcast_ref::<ShuffleWriterExec>() {
+            // use the partitioning of the input to the shuffle write because we are
+            // really executing that and then using the shuffle writer to repartition
+            // the output
+            shuffle.input_plan.output_partitioning().partition_count()
         } else {
-            self.plan.children()[0]
-                .output_partitioning()
-                .partition_count()
+            // for any other plan, use its output partitioning
+            self.plan.output_partitioning().partition_count()
         }
     }
 
     pub fn get_output_partition_count(&self) -> usize {
-        _get_output_partition_count(self.plan.as_ref())
+        // UnknownPartitioning and HashPartitioning with empty expressions will
+        // both return 1 partition.
+        match self.plan.properties().output_partitioning() {
+            Partitioning::UnknownPartitioning(_) => 1,
+            Partitioning::Hash(expr, _) if expr.is_empty() => 1,
+            p => p.partition_count(),
+        }
     }
 }
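
After this change the two counts answer different questions: get_execution_partition_count is the number of tasks to schedule in parallel (for a ShuffleWriterExec, the partition count of its input), while get_output_partition_count is the number of partitions the next stage will read. A minimal sketch of the output-count normalization, restated as a free function with hypothetical assertions (not part of this commit):

use datafusion::physical_plan::Partitioning;

// Sketch: restates the match in get_output_partition_count above.
// UnknownPartitioning and a Hash with no expressions both normalize to 1.
fn normalized_output_partition_count(p: &Partitioning) -> usize {
    match p {
        Partitioning::UnknownPartitioning(_) => 1,
        Partitioning::Hash(expr, _) if expr.is_empty() => 1,
        other => other.partition_count(),
    }
}

fn main() {
    // Both of these describe a single logical output partition.
    assert_eq!(normalized_output_partition_count(&Partitioning::UnknownPartitioning(4)), 1);
    assert_eq!(normalized_output_partition_count(&Partitioning::Hash(vec![], 8)), 1);
}

This is also why the expected plans below change from (2 -> 1) to (1 -> 1): a final SortPreservingMergeExec stage has a single output partition, so under the new definition it runs as one task instead of inheriting its child's two partitions.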

src/shuffle/codec.rs

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ impl PhysicalExtensionCodec for ShuffleCodec {
            };
            PlanType::ShuffleReader(reader)
        } else if let Some(writer) = node.as_any().downcast_ref::<ShuffleWriterExec>() {
-           let plan = PhysicalPlanNode::try_from_physical_plan(writer.plan.clone(), self)?;
+           let plan = PhysicalPlanNode::try_from_physical_plan(writer.input_plan.clone(), self)?;
            let partitioning =
                encode_partitioning_scheme(writer.properties().output_partitioning())?;
            let writer = ShuffleWriterExecNode {

src/shuffle/writer.rs

Lines changed: 5 additions & 5 deletions
@@ -47,7 +47,7 @@ use std::sync::Arc;
 #[derive(Debug)]
 pub struct ShuffleWriterExec {
     pub stage_id: usize,
-    pub(crate) plan: Arc<dyn ExecutionPlan>,
+    pub(crate) input_plan: Arc<dyn ExecutionPlan>,
     /// Output partitioning
     properties: PlanProperties,
     /// Directory to write shuffle files from
@@ -84,7 +84,7 @@ impl ShuffleWriterExec {
 
         Self {
            stage_id,
-           plan,
+           input_plan: plan,
            properties,
            shuffle_dir: shuffle_dir.to_string(),
            metrics: ExecutionPlanMetricsSet::new(),
@@ -98,11 +98,11 @@ impl ExecutionPlan for ShuffleWriterExec {
     }
 
     fn schema(&self) -> SchemaRef {
-        self.plan.schema()
+        self.input_plan.schema()
     }
 
     fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
-        vec![&self.plan]
+        vec![&self.input_plan]
     }
 
     fn with_new_children(
@@ -122,7 +122,7 @@ impl ExecutionPlan for ShuffleWriterExec {
            self.stage_id
        );
 
-       let mut stream = self.plan.execute(input_partition, context)?;
+       let mut stream = self.input_plan.execute(input_partition, context)?;
        let write_time =
            MetricBuilder::new(&self.metrics).subset_time("write_time", input_partition);
        let repart_time =
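
The rename from plan to input_plan makes the data flow explicit: execute streams one partition of the input and repartitions it into shuffle files. A rough sketch of that shape, where write_shuffle_files is a hypothetical stub standing in for the real write loop in this file:

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::{ExecutionPlan, Partitioning, SendableRecordBatchStream};

// Sketch of the shuffle write path after the rename; the real logic lives
// in ShuffleWriterExec::execute.
fn execute_sketch(
    input_plan: &Arc<dyn ExecutionPlan>,
    output_partitioning: &Partitioning,
    input_partition: usize,
    context: Arc<TaskContext>,
) -> Result<()> {
    // Stream partition `input_partition` of the *input* plan...
    let stream = input_plan.execute(input_partition, context)?;
    // ...then hash-repartition its batches into one file per output partition.
    write_shuffle_files(stream, output_partitioning)
}

// Hypothetical stub: the real implementation hashes each batch to an output
// partition and appends it to that partition's shuffle file under shuffle_dir.
fn write_shuffle_files(_stream: SendableRecordBatchStream, _p: &Partitioning) -> Result<()> {
    Ok(())
}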

testdata/expected-plans/q1.txt

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ ShuffleWriterExec(stage_id=1, output_partitioning=Hash([Column { name: "l_return
   CoalesceBatchesExec: target_batch_size=8192
     ShuffleReaderExec(stage_id=0, input_partitioning=Hash([Column { name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 2))
 
-Query Stage #2 (2 -> 1):
+Query Stage #2 (1 -> 1):
 SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST, l_linestatus@1 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=1, input_partitioning=Hash([Column { name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 2))

testdata/expected-plans/q10.txt

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ ShuffleWriterExec(stage_id=7, output_partitioning=Hash([Column { name: "c_custke
   CoalesceBatchesExec: target_batch_size=8192
     ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name: "c_acctbal", index: 2 }, Column { name: "c_phone", index: 3 }, Column { name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name: "c_comment", index: 6 }], 2))
 
-Query Stage #8 (2 -> 1):
+Query Stage #8 (1 -> 1):
 SortPreservingMergeExec: [revenue@2 DESC], fetch=20
   ShuffleReaderExec(stage_id=7, input_partitioning=Hash([Column { name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name: "c_acctbal", index: 3 }, Column { name: "c_phone", index: 6 }, Column { name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name: "c_comment", index: 7 }], 2))

testdata/expected-plans/q11.txt

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ ShuffleWriterExec(stage_id=10, output_partitioning=Hash([Column { name: "ps_part
   CoalesceBatchesExec: target_batch_size=8192
     ShuffleReaderExec(stage_id=9, input_partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 2))
 
-Query Stage #11 (2 -> 1):
+Query Stage #11 (1 -> 1):
 SortPreservingMergeExec: [value@1 DESC]
   ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 2))

testdata/expected-plans/q12.txt

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "l_shipmo
   CoalesceBatchesExec: target_batch_size=8192
     ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "l_shipmode", index: 0 }], 2))
 
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
 SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "l_shipmode", index: 0 }], 2))

testdata/expected-plans/q13.txt

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "c_count"
   CoalesceBatchesExec: target_batch_size=8192
     ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "c_count", index: 0 }], 2))
 
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
 SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC]
   ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "c_count", index: 0 }], 2))
