3 changes: 1 addition & 2 deletions src/common/ttl_map.rs
@@ -27,7 +27,6 @@ use dashmap::{DashMap, Entry};
use datafusion::error::DataFusionError;
use std::collections::HashSet;
use std::hash::Hash;
-use std::mem;
use std::sync::atomic::AtomicU64;
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
@@ -94,7 +93,7 @@ where
shard.insert(key);
}
BucketOp::Clear => {
-let keys_to_delete = mem::replace(&mut shard, HashSet::new());
+let keys_to_delete = std::mem::replace(&mut shard, HashSet::new());
for key in keys_to_delete {
data.remove(&key);
}
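A note on the pattern above: swapping an empty set into the shard moves the old keys out so they can be consumed by value. A minimal sketch of the same drain-by-swap, using plain std collections rather than the crate's DashMap shard types (an assumption made to keep the example self-contained):

use std::collections::{HashMap, HashSet};

// Replacing the shard's set with an empty one takes ownership of the old
// keys, which can then be consumed by value to remove the matching entries.
fn clear_shard(shard: &mut HashSet<u64>, data: &mut HashMap<u64, String>) {
    let keys_to_delete = std::mem::replace(shard, HashSet::new());
    for key in keys_to_delete {
        data.remove(&key);
    }
}

std::mem::take(shard) would be an equivalent, slightly more idiomatic spelling, since HashSet implements Default.
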
28 changes: 28 additions & 0 deletions src/common/util.rs
@@ -1,7 +1,9 @@
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion::error::Result;
use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};

use std::fmt::Write;
use std::sync::Arc;

pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result<String> {
let mut f = String::new();
@@ -34,3 +36,29 @@ pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result<String> {
visit(plan, 0, &mut f)?;
Ok(f)
}

/// Returns a boolean indicating whether this stage can be divided into more than one task.
///
/// Some plan nodes, such as NestedLoopJoinExec, need to materialize all partitions
/// in order to execute. Rewriting the plan to accommodate dividing it into tasks
/// would result in redundant work.
///
/// The plans we cannot split are:
/// - NestedLoopJoinExec
pub fn can_be_divided(plan: &Arc<dyn ExecutionPlan>) -> Result<bool> {
Collaborator:

Usually, if a function is only used in one place, it can be better in the long run to just place it where it's used, as it's not really a common function.

Otherwise, following the same rule, we can end up with massive "utils" or "common" modules containing a lot of unrelated stuff that is not really commonly used across the project.

This one, for example, could just be placed in physical_optimizer.rs.

Collaborator (Author):

Ha! This is why I always end up with massive utils and common modules in my projects! Good feedback. I'll move it to physical_optimizer.rs. 😅

// Recursively check whether this stage's plan contains a NestedLoopJoinExec.
let mut has_unsplittable_plan = false;
let search = |f: &Arc<dyn ExecutionPlan>| {
if f.as_any()
.downcast_ref::<datafusion::physical_plan::joins::NestedLoopJoinExec>()
.is_some()
{
has_unsplittable_plan = true;
return Ok(TreeNodeRecursion::Stop);
}

Ok(TreeNodeRecursion::Continue)
};
plan.apply(search)?;
Ok(!has_unsplittable_plan)
}
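
A hedged usage sketch for the helper above; maybe_limit and its arguments are hypothetical names for illustration, mirroring how the optimizer rule below consumes can_be_divided:

use std::sync::Arc;
use datafusion::error::Result;
use datafusion::physical_plan::ExecutionPlan;

// Only honor a partitions-per-task limit when the stage's plan can actually
// be divided; a NestedLoopJoinExec anywhere below forces a single task.
fn maybe_limit(
    stage_plan: &Arc<dyn ExecutionPlan>,
    partitions_per_task: Option<usize>,
) -> Result<Option<usize>> {
    Ok(match (partitions_per_task, can_be_divided(stage_plan)?) {
        (Some(ppt), true) => Some(ppt), // divisible: apply the configured limit
        _ => None,                      // not divisible, or no limit configured
    })
}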
27 changes: 17 additions & 10 deletions src/physical_optimizer.rs
@@ -1,6 +1,7 @@
use std::sync::Arc;

use super::stage::ExecutionStage;
use crate::common::util::can_be_divided;
use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec};
use datafusion::common::tree_node::TreeNodeRecursion;
use datafusion::error::DataFusionError;
@@ -83,12 +84,14 @@ impl DistributedPhysicalOptimizerRule {
internal_datafusion_err!("Expected RepartitionExec to have a child"),
)?);

-let maybe_isolated_plan = if let Some(ppt) = self.partitions_per_task {
-let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt));
-plan.with_new_children(vec![isolated])?
-} else {
-plan
-};
+let maybe_isolated_plan =
+if can_be_divided(&plan)? && self.partitions_per_task.is_some() {
+let ppt = self.partitions_per_task.unwrap();
+let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt));
+plan.with_new_children(vec![isolated])?
+} else {
+plan
+};

return Ok(Transformed::yes(Arc::new(
ArrowFlightReadExec::new_pending(
@@ -120,7 +123,7 @@ impl DistributedPhysicalOptimizerRule {
) -> Result<ExecutionStage, DataFusionError> {
let mut inputs = vec![];

-let distributed = plan.transform_down(|plan| {
+let distributed = plan.clone().transform_down(|plan| {
let Some(node) = plan.as_any().downcast_ref::<ArrowFlightReadExec>() else {
return Ok(Transformed::no(plan));
};
@@ -137,9 +140,13 @@
let mut stage = ExecutionStage::new(query_id, *num, distributed.data, inputs);
*num += 1;

-if let Some(partitions_per_task) = self.partitions_per_task {
-stage = stage.with_maximum_partitions_per_task(partitions_per_task);
-}
+stage = match (self.partitions_per_task, can_be_divided(&plan)?) {
+(Some(partitions_per_task), true) => {
+stage.with_maximum_partitions_per_task(partitions_per_task)
+}
+(_, _) => stage,
+};

stage.depth = depth;

Ok(stage)
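For context, the rule is registered like any other DataFusion physical optimizer rule. A sketch based on the builder calls used in this PR's tests; the with_physical_optimizer_rule wiring is an assumption about the test harness, not shown in this diff:

use std::sync::Arc;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::SessionContext;
use datafusion_distributed::DistributedPhysicalOptimizerRule;

// Register the distributed rule with a two-partitions-per-task cap; stages
// for which can_be_divided returns false are left as a single task.
fn distributed_ctx() -> SessionContext {
    let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2);
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_physical_optimizer_rule(Arc::new(rule))
        .build();
    SessionContext::new_with_state(state)
}
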
19 changes: 12 additions & 7 deletions src/test_utils/tpch.rs
@@ -160,13 +160,18 @@ where

macro_rules! must_generate_tpch_table {
($generator:ident, $arrow:ident, $name:literal, $data_dir:expr) => {
-generate_table(
-// TODO: Consider adjusting the partitions and batch sizes.
-$arrow::new($generator::new(SCALE_FACTOR, 1, 1)).with_batch_size(1000),
-$name,
-$data_dir,
-)
-.expect(concat!("Failed to generate ", $name, " table"));
+let data_dir = $data_dir.join(format!("{}.parquet", $name));
+fs::create_dir_all(data_dir.clone()).expect("Failed to create data directory");
+// create three partitions for the table
+(1..=3).for_each(|part| {
+generate_table(
+// TODO: Consider adjusting the partitions and batch sizes.
+$arrow::new($generator::new(SCALE_FACTOR, part, 3)).with_batch_size(1000),
+&format!("{}.parquet", part),
+&data_dir.clone().into_boxed_path(),
+)
+.expect(concat!("Failed to generate ", $name, " table"));
+});
};
}

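The macro now writes each table as a directory containing one parquet file per partition. A hypothetical invocation (the generator/arrow type names follow the tpchgen crate's conventions and are assumptions, not taken from this diff):

// Hypothetical call site inside the data-generation routine:
must_generate_tpch_table!(LineItemGenerator, LineItemArrow, "lineitem", &data_dir);

// Produces a directory-per-table layout, one file per partition:
//   <data_dir>/lineitem.parquet/1.parquet
//   <data_dir>/lineitem.parquet/2.parquet
//   <data_dir>/lineitem.parquet/3.parquet
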
@@ -6,6 +6,7 @@ mod tests {
use async_trait::async_trait;
use datafusion::error::DataFusionError;
use datafusion::execution::SessionStateBuilder;

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_distributed::test_utils::localhost::start_localhost_context;
use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder};
@@ -119,8 +120,6 @@ mod tests {
}

#[tokio::test]
-// TODO: Add support for NestedLoopJoinExec to support query 22.
-#[ignore]
Collaborator (on lines -122 to -123): 🎉

async fn test_tpch_22() -> Result<(), Box<dyn Error>> {
test_tpch_query(22).await
}
@@ -155,7 +154,6 @@ mod tests {

config.options_mut().optimizer.prefer_hash_join = true;
// end critical options section

let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2);
Ok(builder
.with_config(config)
@@ -174,7 +172,6 @@
// and once in a non-distributed manner. For each query, it asserts that the results are identical.
async fn run_tpch_query(ctx2: SessionContext, query_id: u8) -> Result<(), Box<dyn Error>> {
ensure_tpch_data().await;

let sql = get_test_tpch_query(query_id);

// Context 1: Non-distributed execution.
@@ -205,19 +202,21 @@
.await?;
}

// Query 15 contains three statements: the first creates a view, the second
// runs the query whose output we want to capture, and the third tears down
// the view.
let (stream1, stream2) = if query_id == 15 {
let queries: Vec<&str> = sql
.split(';')
.map(str::trim)
.filter(|s| !s.is_empty())
.collect();

println!("queryies: {:?}", queries);

ctx1.sql(queries[0]).await?.collect().await?;
ctx2.sql(queries[0]).await?.collect().await?;
let df1 = ctx1.sql(queries[1]).await?;
let df2 = ctx2.sql(queries[1]).await?;

let stream1 = df1.execute_stream().await?;
let stream2 = df2.execute_stream().await?;

@@ -227,6 +226,7 @@
} else {
let stream1 = ctx1.sql(&sql).await?.execute_stream().await?;
let stream2 = ctx2.sql(&sql).await?.execute_stream().await?;

(stream1, stream2)
};
