Skip to content

Commit 35f0b08

Browse files
committed
Merge branch 'main' into gabrielmusat/ergonomy-improvements
# Conflicts: # src/common/ttl_map.rs # tests/tpch_validation_test.rs
2 parents 70c99e1 + 113cc1b commit 35f0b08

File tree

5 files changed

+62
-24
lines changed

5 files changed

+62
-24
lines changed

src/common/ttl_map.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ use dashmap::{DashMap, Entry};
2727
use datafusion::error::DataFusionError;
2828
use std::collections::HashSet;
2929
use std::hash::Hash;
30-
use std::mem;
3130
use std::sync::atomic::AtomicU64;
3231
#[cfg(test)]
3332
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
@@ -94,7 +93,7 @@ where
9493
shard.insert(key);
9594
}
9695
BucketOp::Clear => {
97-
let keys_to_delete = mem::take(&mut shard);
96+
let keys_to_delete = std::mem::take(&mut shard);
9897
for key in keys_to_delete {
9998
data.remove(&key);
10099
}

src/common/util.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
12
use datafusion::error::Result;
23
use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
34

45
use std::fmt::Write;
6+
use std::sync::Arc;
57

68
pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result<String> {
79
let mut f = String::new();

src/physical_optimizer.rs

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,14 @@ impl DistributedPhysicalOptimizerRule {
8383
internal_datafusion_err!("Expected RepartitionExec to have a child"),
8484
)?);
8585

86-
let maybe_isolated_plan = if let Some(ppt) = self.partitions_per_task {
87-
let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt));
88-
plan.with_new_children(vec![isolated])?
89-
} else {
90-
plan
91-
};
86+
let maybe_isolated_plan =
87+
if can_be_divided(&plan)? && self.partitions_per_task.is_some() {
88+
let ppt = self.partitions_per_task.unwrap();
89+
let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt));
90+
plan.with_new_children(vec![isolated])?
91+
} else {
92+
plan
93+
};
9294

9395
return Ok(Transformed::yes(Arc::new(
9496
ArrowFlightReadExec::new_pending(
@@ -120,7 +122,7 @@ impl DistributedPhysicalOptimizerRule {
120122
) -> Result<ExecutionStage, DataFusionError> {
121123
let mut inputs = vec![];
122124

123-
let distributed = plan.transform_down(|plan| {
125+
let distributed = plan.clone().transform_down(|plan| {
124126
let Some(node) = plan.as_any().downcast_ref::<ArrowFlightReadExec>() else {
125127
return Ok(Transformed::no(plan));
126128
};
@@ -137,15 +139,45 @@ impl DistributedPhysicalOptimizerRule {
137139
let mut stage = ExecutionStage::new(query_id, *num, distributed.data, inputs);
138140
*num += 1;
139141

140-
if let Some(partitions_per_task) = self.partitions_per_task {
141-
stage = stage.with_maximum_partitions_per_task(partitions_per_task);
142-
}
142+
stage = match (self.partitions_per_task, can_be_divided(&plan)?) {
143+
(Some(partitions_per_task), true) => {
144+
stage.with_maximum_partitions_per_task(partitions_per_task)
145+
}
146+
(_, _) => stage,
147+
};
148+
143149
stage.depth = depth;
144150

145151
Ok(stage)
146152
}
147153
}
148154

155+
/// Returns a boolean indicating whether this stage can be divided into more than one task.
156+
///
157+
/// Some plan nodes need to materialize all partitions in order to execute, such as
158+
/// NestedLoopJoinExec. Rewriting the plan to accommodate dividing it into tasks
159+
/// would result in redundant work.
160+
///
161+
/// The plans we cannot split are:
162+
/// - NestedLoopJoinExec
///
/// Returns `Ok(false)` if any node visited in `plan` is a NestedLoopJoinExec,
/// and `Ok(true)` otherwise.
///
/// # Errors
///
/// Propagates any error returned by the tree traversal (`plan.apply`).
163+
pub fn can_be_divided(plan: &Arc<dyn ExecutionPlan>) -> Result<bool> {
164+
// Recursively check whether this stage's plan contains a NestedLoopJoinExec.
165+
let mut has_unsplittable_plan = false;
166+
let search = |f: &Arc<dyn ExecutionPlan>| {
167+
if f.as_any()
168+
.downcast_ref::<datafusion::physical_plan::joins::NestedLoopJoinExec>()
169+
.is_some()
170+
{
171+
has_unsplittable_plan = true;
// One match is enough to decide, so stop the traversal here.
172+
return Ok(TreeNodeRecursion::Stop);
173+
}
174+
175+
Ok(TreeNodeRecursion::Continue)
176+
};
177+
plan.apply(search)?;
178+
Ok(!has_unsplittable_plan)
179+
}
180+
149181
#[cfg(test)]
150182
mod tests {
151183
use crate::assert_snapshot;

src/test_utils/tpch.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,18 @@ where
160160

161161
macro_rules! must_generate_tpch_table {
162162
($generator:ident, $arrow:ident, $name:literal, $data_dir:expr) => {
163-
generate_table(
164-
// TODO: Consider adjusting the partitions and batch sizes.
165-
$arrow::new($generator::new(SCALE_FACTOR, 1, 1)).with_batch_size(1000),
166-
$name,
167-
$data_dir,
168-
)
169-
.expect(concat!("Failed to generate ", $name, " table"));
163+
let data_dir = $data_dir.join(format!("{}.parquet", $name));
164+
fs::create_dir_all(data_dir.clone()).expect("Failed to create data directory");
165+
// create three partitions for the table
166+
(1..=3).for_each(|part| {
167+
generate_table(
168+
// TODO: Consider adjusting the partitions and batch sizes.
169+
$arrow::new($generator::new(SCALE_FACTOR, part, 3)).with_batch_size(1000),
170+
&format!("{}.parquet", part),
171+
&data_dir.clone().into_boxed_path(),
172+
)
173+
.expect(concat!("Failed to generate ", $name, " table"));
174+
});
170175
};
171176
}
172177

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ mod tests {
120120
}
121121

122122
#[tokio::test]
123-
// TODO: Add support for NestedLoopJoinExec to support query 22.
124-
#[ignore]
125123
async fn test_tpch_22() -> Result<(), Box<dyn Error>> {
126124
test_tpch_query(22).await
127125
}
@@ -164,7 +162,6 @@ mod tests {
164162
// and once in a non-distributed manner. For each query, it asserts that the results are identical.
165163
async fn run_tpch_query(ctx2: SessionContext, query_id: u8) -> Result<(), Box<dyn Error>> {
166164
ensure_tpch_data().await;
167-
168165
let sql = get_test_tpch_query(query_id);
169166

170167
// Context 1: Non-distributed execution.
@@ -195,19 +192,21 @@ mod tests {
195192
.await?;
196193
}
197194

195+
// Query 15 consists of three statements: the first creates the view, the
196+
// second runs the query whose output we want to capture, and the third
197+
// tears down the view.
198198
let (stream1, stream2) = if query_id == 15 {
199199
let queries: Vec<&str> = sql
200200
.split(';')
201201
.map(str::trim)
202202
.filter(|s| !s.is_empty())
203203
.collect();
204204

205-
println!("queryies: {:?}", queries);
206-
207205
ctx1.sql(queries[0]).await?.collect().await?;
208206
ctx2.sql(queries[0]).await?.collect().await?;
209207
let df1 = ctx1.sql(queries[1]).await?;
210208
let df2 = ctx2.sql(queries[1]).await?;
209+
211210
let stream1 = df1.execute_stream().await?;
212211
let stream2 = df2.execute_stream().await?;
213212

@@ -217,6 +216,7 @@ mod tests {
217216
} else {
218217
let stream1 = ctx1.sql(&sql).await?.execute_stream().await?;
219218
let stream2 = ctx2.sql(&sql).await?.execute_stream().await?;
219+
220220
(stream1, stream2)
221221
};
222222

0 commit comments

Comments
 (0)