Skip to content

Commit 5fbea4b

Browse files
committed
Defensively verifiying that all children plans have the same count of output partitions
1 parent 151a0e2 commit 5fbea4b

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

src/query_stage.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use datafusion::prelude::SessionContext;
2323
use datafusion_proto::bytes::physical_plan_from_bytes_with_extension_codec;
2424
use pyo3::prelude::*;
2525
use pyo3::types::PyBytes;
26+
use std::collections::HashSet;
2627
use std::sync::Arc;
2728

2829
#[pyclass(name = "QueryStage", module = "datafusion_ray", subclass)]
@@ -99,14 +100,23 @@ impl QueryStage {
99100
/// Get the input partition count. This is the same as the number of concurrent tasks
100101
/// when we schedule this query stage for execution
101102
pub fn get_input_partition_count(&self) -> usize {
102-
if self.plan.children().is_empty() {
103-
// leaf node (file scan)
104-
self.plan.output_partitioning().partition_count()
105-
} else {
106-
self.plan.children()[0]
107-
.output_partitioning()
108-
.partition_count()
103+
let mut output_partition_counts = HashSet::new();
104+
105+
for child in self.plan.children() {
106+
output_partition_counts.insert(child.output_partitioning().partition_count());
107+
if output_partition_counts.len() > 1 {
108+
panic!(
109+
"Children plan of {:#?} have a distinct outout partitioning partition count",
110+
self.plan
111+
);
112+
}
109113
}
114+
// If this stage is a leaf node (file scan), it won't have children
115+
// so we return the partition count of the plan itself
116+
output_partition_counts
117+
.into_iter()
118+
.next()
119+
.unwrap_or(self.plan.output_partitioning().partition_count())
110120
}
111121

112122
pub fn get_output_partition_count(&self) -> usize {

0 commit comments

Comments
 (0)