diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs
index 015b660..1f8085b 100644
--- a/benchmarks/src/tpch/run.rs
+++ b/benchmarks/src/tpch/run.rs
@@ -159,10 +159,14 @@ impl DistributedSessionBuilder for RunOpt {
             builder = builder.with_physical_optimizer_rule(Arc::new(InMemoryDataSourceRule));
         }
         if !self.workers.is_empty() {
-            let rule = DistributedPhysicalOptimizerRule::new()
-                .with_network_coalesce_tasks(self.coalesce_tasks.unwrap_or(self.workers.len()))
-                .with_network_shuffle_tasks(self.shuffle_tasks.unwrap_or(self.workers.len()));
-            builder = builder.with_physical_optimizer_rule(Arc::new(rule));
+            builder = builder
+                .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
+                .with_distributed_network_coalesce_tasks(
+                    self.coalesce_tasks.unwrap_or(self.workers.len()),
+                )
+                .with_distributed_network_shuffle_tasks(
+                    self.shuffle_tasks.unwrap_or(self.workers.len()),
+                );
         }
 
         Ok(builder
diff --git a/examples/in_memory_cluster.rs b/examples/in_memory_cluster.rs
index 9778761..9a084f8 100644
--- a/examples/in_memory_cluster.rs
+++ b/examples/in_memory_cluster.rs
@@ -28,6 +28,12 @@ struct Args {
 
     #[structopt(long)]
     explain: bool,
+
+    #[structopt(long, default_value = "3")]
+    network_shuffle_tasks: usize,
+
+    #[structopt(long, default_value = "3")]
+    network_coalesce_tasks: usize,
 }
 
 #[tokio::main]
@@ -37,7 +43,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let state = SessionStateBuilder::new()
         .with_default_features()
         .with_distributed_channel_resolver(InMemoryChannelResolver::new())
-        .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule::new()))
+        .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
+        .with_distributed_network_coalesce_tasks(args.network_coalesce_tasks)
+        .with_distributed_network_shuffle_tasks(args.network_shuffle_tasks)
         .build();
     let ctx = SessionContext::from(state);
 
diff --git a/examples/localhost_run.rs b/examples/localhost_run.rs
index bfdb87d..0035a0c 100644
--- a/examples/localhost_run.rs
+++ b/examples/localhost_run.rs
@@ -48,11 +48,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let state = SessionStateBuilder::new()
         .with_default_features()
         .with_distributed_channel_resolver(localhost_resolver)
-        .with_physical_optimizer_rule(Arc::new(
-            DistributedPhysicalOptimizerRule::new()
-                .with_network_coalesce_tasks(args.network_coalesce_tasks)
-                .with_network_shuffle_tasks(args.network_shuffle_tasks),
-        ))
+        .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
+        .with_distributed_network_coalesce_tasks(args.network_coalesce_tasks)
+        .with_distributed_network_shuffle_tasks(args.network_shuffle_tasks)
         .build();
     let ctx = SessionContext::from(state);
 
diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs
index 4a7b483..ddcb008 100644
--- a/src/distributed_ext.rs
+++ b/src/distributed_ext.rs
@@ -1,9 +1,12 @@
-use crate::ChannelResolver;
 use crate::channel_resolver_ext::set_distributed_channel_resolver;
 use crate::config_extension_ext::{
     set_distributed_option_extension, set_distributed_option_extension_from_headers,
 };
+use crate::distributed_planner::{
+    set_distributed_network_coalesce_tasks, set_distributed_network_shuffle_tasks,
+};
 use crate::protobuf::{set_distributed_user_codec, set_distributed_user_codec_arc};
+use crate::{ChannelResolver, IntoPlanDependentUsize};
 use datafusion::common::DataFusionError;
 use datafusion::config::ConfigExtension;
 use datafusion::execution::{SessionState, SessionStateBuilder};
@@ -221,6 +224,32 @@ pub trait DistributedExt: Sized {
         &mut self,
resolver: T, ); + + /// Upon merging multiple tasks into one, this defines how many tasks are merged. + /// ```text + /// ( task 1 ) + /// ▲ + /// ┌───────────┴──────────┐ + /// ( task 1 ) ( task 2 ) ( task 3 ) N tasks + /// ``` + /// This parameter defines N + fn with_distributed_network_coalesce_tasks(self, tasks: T) -> Self; + + /// Same as [DistributedExt::with_distributed_network_coalesce_tasks] but with an in-place mutation. + fn set_distributed_network_coalesce_tasks(&mut self, tasks: T); + + /// Upon shuffling data, this defines how many tasks are employed into performing the shuffling. + /// ```text + /// ( task 1 ) ( task 2 ) ( task 3 ) + /// ▲ ▲ ▲ + /// └────┬──────┴─────┬────┘ + /// ( task 1 ) ( task 2 ) N tasks + /// ``` + /// This parameter defines N + fn with_distributed_network_shuffle_tasks(self, tasks: T) -> Self; + + /// Same as [DistributedExt::with_distributed_network_shuffle_tasks] but with an in-place mutation. + fn set_distributed_network_shuffle_tasks(&mut self, tasks: T); } impl DistributedExt for SessionConfig { @@ -253,6 +282,14 @@ impl DistributedExt for SessionConfig { set_distributed_channel_resolver(self, resolver) } + fn set_distributed_network_coalesce_tasks(&mut self, tasks: T) { + set_distributed_network_coalesce_tasks(self, tasks) + } + + fn set_distributed_network_shuffle_tasks(&mut self, tasks: T) { + set_distributed_network_shuffle_tasks(self, tasks) + } + delegate! { to self { #[call(set_distributed_option_extension)] @@ -274,6 +311,14 @@ impl DistributedExt for SessionConfig { #[call(set_distributed_channel_resolver)] #[expr($;self)] fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; + + #[call(set_distributed_network_coalesce_tasks)] + #[expr($;self)] + fn with_distributed_network_coalesce_tasks(mut self, tasks: T) -> Self; + + #[call(set_distributed_network_shuffle_tasks)] + #[expr($;self)] + fn with_distributed_network_shuffle_tasks(mut self, tasks: T) -> Self; } } } @@ -305,6 +350,16 @@ impl DistributedExt for SessionStateBuilder { #[call(set_distributed_channel_resolver)] #[expr($;self)] fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; + + fn set_distributed_network_coalesce_tasks(&mut self, tasks: T); + #[call(set_distributed_network_coalesce_tasks)] + #[expr($;self)] + fn with_distributed_network_coalesce_tasks(mut self, tasks: T) -> Self; + + fn set_distributed_network_shuffle_tasks(&mut self, tasks: T); + #[call(set_distributed_network_shuffle_tasks)] + #[expr($;self)] + fn with_distributed_network_shuffle_tasks(mut self, tasks: T) -> Self; } } } @@ -336,6 +391,16 @@ impl DistributedExt for SessionState { #[call(set_distributed_channel_resolver)] #[expr($;self)] fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; + + fn set_distributed_network_coalesce_tasks(&mut self, tasks: T); + #[call(set_distributed_network_coalesce_tasks)] + #[expr($;self)] + fn with_distributed_network_coalesce_tasks(mut self, tasks: T) -> Self; + + fn set_distributed_network_shuffle_tasks(&mut self, tasks: T); + #[call(set_distributed_network_shuffle_tasks)] + #[expr($;self)] + fn with_distributed_network_shuffle_tasks(mut self, tasks: T) -> Self; } } } @@ -367,6 +432,16 @@ impl DistributedExt for SessionContext { #[call(set_distributed_channel_resolver)] #[expr($;self)] fn with_distributed_channel_resolver(self, resolver: T) -> Self; + + fn set_distributed_network_coalesce_tasks(&mut self, tasks: T); + #[call(set_distributed_network_coalesce_tasks)] + #[expr($;self)] + fn 
with_distributed_network_coalesce_tasks(self, tasks: T) -> Self;
+
+            fn set_distributed_network_shuffle_tasks(&mut self, tasks: T);
+            #[call(set_distributed_network_shuffle_tasks)]
+            #[expr($;self)]
+            fn with_distributed_network_shuffle_tasks(self, tasks: T) -> Self;
         }
     }
 }
diff --git a/src/distributed_planner/distributed_config.rs b/src/distributed_planner/distributed_config.rs
new file mode 100644
index 0000000..51572e1
--- /dev/null
+++ b/src/distributed_planner/distributed_config.rs
@@ -0,0 +1,126 @@
+use datafusion::common::extensions_options;
+use datafusion::config::{ConfigExtension, ConfigField, Visit, default_config_transform};
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::SessionConfig;
+use std::fmt::{Debug, Display, Formatter};
+use std::sync::Arc;
+
+#[derive(Clone)]
+#[allow(clippy::type_complexity)]
+pub struct PlanDependentUsize(
+    pub(crate) Arc<dyn Fn(&Arc<dyn ExecutionPlan>) -> usize + Send + Sync>,
+);
+
+impl PlanDependentUsize {
+    pub fn call(&self, plan: &Arc<dyn ExecutionPlan>) -> usize {
+        self.0(plan)
+    }
+}
+
+pub trait IntoPlanDependentUsize {
+    fn into_plan_dependent_usize(self) -> PlanDependentUsize;
+}
+
+impl IntoPlanDependentUsize for usize {
+    fn into_plan_dependent_usize(self) -> PlanDependentUsize {
+        PlanDependentUsize(Arc::new(move |_| self))
+    }
+}
+
+impl<T: Fn(&Arc<dyn ExecutionPlan>) -> usize + Send + Sync + 'static> IntoPlanDependentUsize for T {
+    fn into_plan_dependent_usize(self) -> PlanDependentUsize {
+        PlanDependentUsize(Arc::new(self))
+    }
+}
+
+impl Default for PlanDependentUsize {
+    fn default() -> Self {
+        PlanDependentUsize(Arc::new(|_| 0))
+    }
+}
+
+impl Debug for PlanDependentUsize {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "PlanDependentUsize")
+    }
+}
+
+impl Display for PlanDependentUsize {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "PlanDependentUsize")
+    }
+}
+
+impl ConfigField for PlanDependentUsize {
+    fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
+        v.some(key, self, description);
+    }
+
+    fn set(&mut self, _: &str, value: &str) -> datafusion::common::Result<()> {
+        *self = default_config_transform::<usize>(value)?.into_plan_dependent_usize();
+        Ok(())
+    }
+}
+
+extensions_options! {
+    pub struct DistributedConfig {
+        /// Upon shuffling data, this defines how many tasks are employed in performing the shuffle.
+        /// ```text
+        ///   ( task 1 )  ( task 2 ) ( task 3 )
+        ///       ▲           ▲          ▲
+        ///       └────┬──────┴─────┬────┘
+        ///        ( task 1 )   ( task 2 )    N tasks
+        /// ```
+        /// This parameter defines N
+        pub network_shuffle_tasks: Option<PlanDependentUsize>, default = None
+        /// Upon merging multiple tasks into one, this defines how many tasks are merged.
+        /// ```text
+        ///               ( task 1 )
+        ///                   ▲
+        ///       ┌───────────┴──────────┐
+        ///   ( task 1 )  ( task 2 ) ( task 3 )   N tasks
+        /// ```
+        /// This parameter defines N
+        pub network_coalesce_tasks: Option<PlanDependentUsize>, default = None
+    }
+}
+
+impl ConfigExtension for DistributedConfig {
+    const PREFIX: &'static str = "distributed";
+}
+
+impl DistributedConfig {
+    /// Sets the number of tasks used in a network shuffle operation.
+    pub fn with_network_shuffle_tasks(mut self, tasks: impl IntoPlanDependentUsize) -> Self {
+        self.network_shuffle_tasks = Some(tasks.into_plan_dependent_usize());
+        self
+    }
+
+    /// Sets the number of tasks used in a network coalesce operation.
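+    ///
+    /// For example (an illustrative sketch), either a fixed count or a closure that derives
+    /// the count from the plan being distributed can be passed, since both implement
+    /// [IntoPlanDependentUsize]:
+    /// ```rust,ignore
+    /// // Fixed number of input tasks:
+    /// let cfg = DistributedConfig::default().with_network_coalesce_tasks(4);
+    ///
+    /// // Plan-dependent: cap the task count by the plan's partition count.
+    /// let cfg = DistributedConfig::default().with_network_coalesce_tasks(
+    ///     |plan: &Arc<dyn ExecutionPlan>| plan.output_partitioning().partition_count().min(4),
+    /// );
+    /// ```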
+ pub fn with_network_coalesce_tasks(mut self, tasks: impl IntoPlanDependentUsize) -> Self { + self.network_coalesce_tasks = Some(tasks.into_plan_dependent_usize()); + self + } +} + +pub(crate) fn set_distributed_network_coalesce_tasks( + cfg: &mut SessionConfig, + tasks: impl IntoPlanDependentUsize, +) { + let ext = &mut cfg.options_mut().extensions; + let Some(prev) = ext.get_mut::() else { + return ext.insert(DistributedConfig::default().with_network_coalesce_tasks(tasks)); + }; + prev.network_coalesce_tasks = Some(tasks.into_plan_dependent_usize()); +} + +pub(crate) fn set_distributed_network_shuffle_tasks( + cfg: &mut SessionConfig, + tasks: impl IntoPlanDependentUsize, +) { + let ext = &mut cfg.options_mut().extensions; + let Some(prev) = ext.get_mut::() else { + return ext.insert(DistributedConfig::default().with_network_shuffle_tasks(tasks)); + }; + prev.network_shuffle_tasks = Some(tasks.into_plan_dependent_usize()); +} diff --git a/src/distributed_planner/distributed_physical_optimizer_rule.rs b/src/distributed_planner/distributed_physical_optimizer_rule.rs index d3ffc48..a964a43 100644 --- a/src/distributed_planner/distributed_physical_optimizer_rule.rs +++ b/src/distributed_planner/distributed_physical_optimizer_rule.rs @@ -1,3 +1,4 @@ +use crate::distributed_planner::distributed_config::DistributedConfig; use crate::distributed_planner::distributed_plan_error::get_distribute_plan_err; use crate::distributed_planner::{ DistributedPlanError, NetworkBoundaryExt, limit_tasks_err, non_distributable_err, @@ -54,57 +55,19 @@ use uuid::Uuid; /// like when a plan is not parallelizable in different tasks (e.g. a collect left [HashJoinExec]) /// or when a [DataSourceExec] has not enough partitions to be spread across tasks. #[derive(Debug, Default)] -pub struct DistributedPhysicalOptimizerRule { - /// Upon shuffling data, this defines how many tasks are employed into performing the shuffling. - /// ```text - /// ( task 1 ) ( task 2 ) ( task 3 ) - /// ▲ ▲ ▲ - /// └────┬──────┴─────┬────┘ - /// ( task 1 ) ( task 2 ) N tasks - /// ``` - /// This parameter defines N - network_shuffle_tasks: Option, - /// Upon merging multiple tasks into one, this defines how many tasks are merged. - /// ```text - /// ( task 1 ) - /// ▲ - /// ┌───────────┴──────────┐ - /// ( task 1 ) ( task 2 ) ( task 3 ) N tasks - /// ``` - /// This parameter defines N - network_coalesce_tasks: Option, -} - -impl DistributedPhysicalOptimizerRule { - pub fn new() -> Self { - DistributedPhysicalOptimizerRule { - network_shuffle_tasks: None, - network_coalesce_tasks: None, - } - } - - /// Sets the amount of tasks employed in performing shuffles. - pub fn with_network_shuffle_tasks(mut self, tasks: usize) -> Self { - self.network_shuffle_tasks = Some(tasks); - self - } - - /// Sets the amount of input tasks for every task coalescing operation. - pub fn with_network_coalesce_tasks(mut self, tasks: usize) -> Self { - self.network_coalesce_tasks = Some(tasks); - self - } -} +pub struct DistributedPhysicalOptimizerRule; impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { fn optimize( &self, plan: Arc, - _config: &ConfigOptions, + config: &ConfigOptions, ) -> Result> { + let Some(cfg) = config.extensions.get::() else { + return Ok(plan); + }; // We can only optimize plans that are not already distributed - let plan = self.apply_network_boundaries(plan)?; - Self::distribute_plan(plan) + distribute_plan(apply_network_boundaries(plan, cfg)?) 
} fn name(&self) -> &str { @@ -116,196 +79,216 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { } } -impl DistributedPhysicalOptimizerRule { - fn apply_network_boundaries( - &self, - mut plan: Arc, - ) -> Result, DataFusionError> { - if plan.output_partitioning().partition_count() > 1 { - // Coalescing partitions here will allow us to put a NetworkCoalesceExec on top - // of the plan, executing it in parallel. - plan = Arc::new(CoalescePartitionsExec::new(plan)) - } +/// Places the appropriate [NetworkBoundary]s in the plan. It will look for certain nodes in the +/// provided plan and wrap them with their distributed equivalent, for example: +/// - A [RepartitionExec] will be wrapped with a [NetworkShuffleExec] for performing the +/// repartition over the network (shuffling). +/// - A [CoalescePartitionsExec] and a [SortPreservingMergeExec] both coalesce P partitions into +/// one, so a [NetworkCoalesceExec] is injected right below them to also coalesce distributed +/// tasks. +/// - A [DataSourceExec] is wrapped with a [PartitionIsolatorExec] so that each distributed task +/// only executes a certain amount of partitions. +/// +/// How many tasks are employed in each step is controlled by the user through [DistributedConfig]. +pub fn apply_network_boundaries( + mut plan: Arc, + cfg: &DistributedConfig, +) -> Result, DataFusionError> { + if plan.output_partitioning().partition_count() > 1 { + // Coalescing partitions here will allow us to put a NetworkCoalesceExec on top + // of the plan, executing it in parallel. + plan = Arc::new(CoalescePartitionsExec::new(plan)) + } - let result = - plan.transform_up(|plan| { - // If this node is a DataSourceExec, we need to wrap it with PartitionIsolatorExec so - // that not all tasks have access to all partitions of the underlying DataSource. - if plan.as_any().is::() { - let node = PartitionIsolatorExec::new(plan); + let result = plan.transform_up(|plan| { + // If this node is a DataSourceExec, we need to wrap it with PartitionIsolatorExec so + // that not all tasks have access to all partitions of the underlying DataSource. + if plan.as_any().is::() { + let node = PartitionIsolatorExec::new(plan); - return Ok(Transformed::yes(Arc::new(node))); - } + return Ok(Transformed::yes(Arc::new(node))); + } - // If this is a hash RepartitionExec, introduce a shuffle. - if let (Some(node), Some(tasks)) = ( - plan.as_any().downcast_ref::(), - self.network_shuffle_tasks, - ) { - if !matches!(node.partitioning(), Partitioning::Hash(_, _)) { - return Ok(Transformed::no(plan)); - } - let node = NetworkShuffleExec::try_new(plan, tasks)?; + // If this is a hash RepartitionExec, introduce a shuffle. + if let (Some(node), Some(tasks)) = ( + plan.as_any().downcast_ref::(), + cfg.network_shuffle_tasks.clone(), + ) { + if !matches!(node.partitioning(), Partitioning::Hash(_, _)) { + return Ok(Transformed::no(plan)); + } + let input_tasks = tasks.0(&plan); + if input_tasks == 0 { + return Ok(Transformed::no(plan)); + } + let node = NetworkShuffleExec::try_new(plan, input_tasks)?; - return Ok(Transformed::yes(Arc::new(node))); - } + return Ok(Transformed::yes(Arc::new(node))); + } - // If this is a CoalescePartitionsExec, it means that the original plan is trying to - // merge all partitions into one. We need to go one step ahead and also merge all tasks - // into one. 
- if let (Some(node), Some(tasks)) = ( - plan.as_any().downcast_ref::(), - self.network_coalesce_tasks, - ) { - // If the immediate child is a PartitionIsolatorExec, it means that the rest of the - // plan is just a couple of non-computational nodes that are probably not worth - // distributing. - if node.input().as_any().is::() { - return Ok(Transformed::no(plan)); - } + // If this is a CoalescePartitionsExec, it means that the original plan is trying to + // merge all partitions into one. We need to go one step ahead and also merge all tasks + // into one. + if let (Some(node), Some(tasks)) = ( + plan.as_any().downcast_ref::(), + cfg.network_coalesce_tasks.clone(), + ) { + // If the immediate child is a PartitionIsolatorExec, it means that the rest of the + // plan is just a couple of non-computational nodes that are probably not worth + // distributing. + if node.input().as_any().is::() { + return Ok(Transformed::no(plan)); + } - let plan = plan.clone().with_new_children(vec![Arc::new( - NetworkCoalesceExec::new(Arc::clone(node.input()), tasks), - )])?; + let input_tasks = tasks.0(&plan); + if input_tasks == 0 { + return Ok(Transformed::no(plan)); + } + let plan = Arc::clone(&plan).with_new_children(vec![Arc::new( + NetworkCoalesceExec::new(Arc::clone(node.input()), input_tasks), + )])?; - return Ok(Transformed::yes(plan)); - } + return Ok(Transformed::yes(plan)); + } - // The SortPreservingMergeExec node will try to coalesce all partitions into just 1. - // We need to account for it and help it by also coalescing all tasks into one, therefore - // a NetworkCoalesceExec is introduced. - if let (Some(node), Some(tasks)) = ( - plan.as_any().downcast_ref::(), - self.network_coalesce_tasks, - ) { - let plan = plan.clone().with_new_children(vec![Arc::new( - NetworkCoalesceExec::new(Arc::clone(node.input()), tasks), - )])?; - - return Ok(Transformed::yes(plan)); - } + // The SortPreservingMergeExec node will try to coalesce all partitions into just 1. + // We need to account for it and help it by also coalescing all tasks into one, therefore + // a NetworkCoalesceExec is introduced. + if let (Some(node), Some(tasks)) = ( + plan.as_any().downcast_ref::(), + cfg.network_coalesce_tasks.clone(), + ) { + let input_tasks = tasks.0(&plan); + if input_tasks == 0 { + return Ok(Transformed::no(plan)); + } + let plan = Arc::clone(&plan).with_new_children(vec![Arc::new( + NetworkCoalesceExec::new(Arc::clone(node.input()), input_tasks), + )])?; - Ok(Transformed::no(plan)) - })?; - Ok(result.data) - } + return Ok(Transformed::yes(plan)); + } - /// Takes a plan with certain network boundaries in it ([NetworkShuffleExec], [NetworkCoalesceExec], ...) - /// and breaks it down into stages. - /// - /// This can be used a standalone function for distributing arbitrary plans in which users have - /// manually placed network boundaries, or as part of the [DistributedPhysicalOptimizerRule] that - /// places the network boundaries automatically as a standard [PhysicalOptimizerRule]. - pub fn distribute_plan( - plan: Arc, - ) -> Result, DataFusionError> { - let stage = match Self::_distribute_plan_inner(Uuid::new_v4(), plan.clone(), &mut 1, 0, 1) { - Ok(stage) => stage, - Err(err) => { - return match get_distribute_plan_err(&err) { - Some(DistributedPlanError::NonDistributable(_)) => plan - .transform_down(|plan| { - // If the node cannot be distributed, rollback all the network boundaries. 
- if let Some(nb) = plan.as_network_boundary() { - return Ok(Transformed::yes(nb.rollback()?)); - } - Ok(Transformed::no(plan)) - }) - .map(|v| v.data), - _ => Err(err), - }; - } - }; - let plan = stage.plan.decoded()?; - Ok(Arc::new(DistributedExec::new(Arc::clone(plan)))) - } + Ok(Transformed::no(plan)) + })?; + Ok(result.data) +} - fn _distribute_plan_inner( - query_id: Uuid, - plan: Arc, - num: &mut usize, - depth: usize, - n_tasks: usize, - ) -> Result { - let mut distributed = plan.clone().transform_down(|plan| { - // We cannot break down CollectLeft hash joins into more than 1 task, as these need - // a full materialized build size with all the data in it. - // - // Maybe in the future these can be broadcast joins? - if let Some(node) = plan.as_any().downcast_ref::() { - if n_tasks > 1 && node.mode == PartitionMode::CollectLeft { - return Err(limit_tasks_err(1)); - } - } +/// Takes a plan with certain network boundaries in it ([NetworkShuffleExec], [NetworkCoalesceExec], ...) +/// and breaks it down into stages. +/// +/// This can be used a standalone function for distributing arbitrary plans in which users have +/// manually placed network boundaries, or as part of the [DistributedPhysicalOptimizerRule] that +/// places the network boundaries automatically as a standard [PhysicalOptimizerRule]. +pub fn distribute_plan( + plan: Arc, +) -> Result, DataFusionError> { + let stage = match _distribute_plan_inner(Uuid::new_v4(), plan.clone(), &mut 1, 0, 1) { + Ok(stage) => stage, + Err(err) => { + return match get_distribute_plan_err(&err) { + Some(DistributedPlanError::NonDistributable(_)) => plan + .transform_down(|plan| { + // If the node cannot be distributed, rollback all the network boundaries. + if let Some(nb) = plan.as_network_boundary() { + return Ok(Transformed::yes(nb.rollback()?)); + } + Ok(Transformed::no(plan)) + }) + .map(|v| v.data), + _ => Err(err), + }; + } + }; + let plan = stage.plan.decoded()?; + Ok(Arc::new(DistributedExec::new(Arc::clone(plan)))) +} - // We cannot distribute [StreamingTableExec] nodes, so abort distribution. - if plan.as_any().is::() { - return Err(non_distributable_err(StreamingTableExec::static_name())) +fn _distribute_plan_inner( + query_id: Uuid, + plan: Arc, + num: &mut usize, + depth: usize, + n_tasks: usize, +) -> Result { + let mut distributed = plan.clone().transform_down(|plan| { + // We cannot break down CollectLeft hash joins into more than 1 task, as these need + // a full materialized build size with all the data in it. + // + // Maybe in the future these can be broadcast joins? + if let Some(node) = plan.as_any().downcast_ref::() { + if n_tasks > 1 && node.mode == PartitionMode::CollectLeft { + return Err(limit_tasks_err(1)); } + } - if let Some(node) = plan.as_any().downcast_ref::() { - // If there's only 1 task, no need to perform any isolation. - if n_tasks == 1 { - return Ok(Transformed::yes(Arc::clone(plan.children().first().unwrap()))); - } - let node = node.ready(n_tasks)?; - return Ok(Transformed::new(Arc::new(node), true, TreeNodeRecursion::Jump)); + // We cannot distribute [StreamingTableExec] nodes, so abort distribution. + if plan.as_any().is::() { + return Err(non_distributable_err(StreamingTableExec::static_name())) + } + + if let Some(node) = plan.as_any().downcast_ref::() { + // If there's only 1 task, no need to perform any isolation. 
+ if n_tasks == 1 { + return Ok(Transformed::yes(Arc::clone(plan.children().first().unwrap()))); } + let node = node.ready(n_tasks)?; + return Ok(Transformed::new(Arc::new(node), true, TreeNodeRecursion::Jump)); + } - let Some(mut dnode) = plan.as_network_boundary().map(Referenced::Borrowed) else { - return Ok(Transformed::no(plan)); - }; + let Some(mut dnode) = plan.as_network_boundary().map(Referenced::Borrowed) else { + return Ok(Transformed::no(plan)); + }; - let stage = loop { - let input_stage_info = dnode.as_ref().get_input_stage_info(n_tasks)?; - // If the current stage has just 1 task, and the next stage is only going to have - // 1 task, there's no point in having a network boundary in between, they can just - // communicate in memory. - if n_tasks == 1 && input_stage_info.task_count == 1 { - let mut n = dnode.as_ref().rollback()?; - if let Some(node) = n.as_any().downcast_ref::() { - // Also trim PartitionIsolatorExec out of the plan. - n = Arc::clone(node.children().first().unwrap()); - } - return Ok(Transformed::yes(n)); + let stage = loop { + let input_stage_info = dnode.as_ref().get_input_stage_info(n_tasks)?; + // If the current stage has just 1 task, and the next stage is only going to have + // 1 task, there's no point in having a network boundary in between, they can just + // communicate in memory. + if n_tasks == 1 && input_stage_info.task_count == 1 { + let mut n = dnode.as_ref().rollback()?; + if let Some(node) = n.as_any().downcast_ref::() { + // Also trim PartitionIsolatorExec out of the plan. + n = Arc::clone(node.children().first().unwrap()); } - match Self::_distribute_plan_inner(query_id, input_stage_info.plan, num, depth + 1, input_stage_info.task_count) { - Ok(v) => break v, - Err(e) => match get_distribute_plan_err(&e) { - None => return Err(e), - Some(DistributedPlanError::LimitTasks(limit)) => { - // While attempting to build a new stage, a failure was raised stating - // that no more than `limit` tasks can be used for it, so we are going - // to limit the amount of tasks to the requested number and try building - // the stage again. - if input_stage_info.task_count == *limit { - return plan_err!("A node requested {limit} tasks for the stage its in, but that stage already has that many tasks"); - } - dnode = Referenced::Arced(dnode.as_ref().with_input_task_count(*limit)?); - } - Some(DistributedPlanError::NonDistributable(_)) => { - // This full plan is non-distributable, so abort any task and stage - // assignation. - return Err(e); + return Ok(Transformed::yes(n)); + } + match _distribute_plan_inner(query_id, input_stage_info.plan, num, depth + 1, input_stage_info.task_count) { + Ok(v) => break v, + Err(e) => match get_distribute_plan_err(&e) { + None => return Err(e), + Some(DistributedPlanError::LimitTasks(limit)) => { + // While attempting to build a new stage, a failure was raised stating + // that no more than `limit` tasks can be used for it, so we are going + // to limit the amount of tasks to the requested number and try building + // the stage again. + if input_stage_info.task_count == *limit { + return plan_err!("A node requested {limit} tasks for the stage its in, but that stage already has that many tasks"); } - }, - } - }; - let node = dnode.as_ref().with_input_stage(stage)?; - Ok(Transformed::new(node, true, TreeNodeRecursion::Jump)) - })?; - - // The head stage is executable, and upon execution, it will lazily assign worker URLs to - // all tasks. 
This must only be done once, so the executable StageExec must only be called - // once on 1 partition. - if depth == 0 && distributed.data.output_partitioning().partition_count() > 1 { - distributed.data = Arc::new(CoalescePartitionsExec::new(distributed.data)); - } - - let stage = Stage::new(query_id, *num, distributed.data, n_tasks); - *num += 1; - Ok(stage) + dnode = Referenced::Arced(dnode.as_ref().with_input_task_count(*limit)?); + } + Some(DistributedPlanError::NonDistributable(_)) => { + // This full plan is non-distributable, so abort any task and stage + // assignation. + return Err(e); + } + }, + } + }; + let node = dnode.as_ref().with_input_stage(stage)?; + Ok(Transformed::new(node, true, TreeNodeRecursion::Jump)) + })?; + + // The head stage is executable, and upon execution, it will lazily assign worker URLs to + // all tasks. This must only be done once, so the executable StageExec must only be called + // once on 1 partition. + if depth == 0 && distributed.data.output_partitioning().partition_count() > 1 { + distributed.data = Arc::new(CoalescePartitionsExec::new(distributed.data)); } + + let stage = Stage::new(query_id, *num, distributed.data, n_tasks); + *num += 1; + Ok(stage) } /// Helper enum for storing either borrowed or owned trait object references enum Referenced<'a, T: ?Sized> { @@ -324,8 +307,8 @@ impl Referenced<'_, T> { #[cfg(test)] mod tests { - use crate::distributed_planner::distributed_physical_optimizer_rule::DistributedPhysicalOptimizerRule; use crate::test_utils::parquet::register_parquet_tables; + use crate::{DistributedConfig, DistributedPhysicalOptimizerRule}; use crate::{assert_snapshot, display_plan_ascii}; use datafusion::error::DataFusionError; use datafusion::execution::SessionStateBuilder; @@ -359,8 +342,13 @@ mod tests { #[tokio::test] async fn test_select_all() { - let query = r#"SELECT * FROM weather"#; - let plan = sql_to_explain(query, 1).await.unwrap(); + let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + + SELECT * FROM weather + "#; + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ CoalescePartitionsExec @@ -371,9 +359,13 @@ mod tests { #[tokio::test] async fn test_aggregation() { - let query = - r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#; - let plan = sql_to_explain(query, 2).await.unwrap(); + let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + + SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*) + "#; + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] @@ -399,9 +391,13 @@ mod tests { #[tokio::test] async fn test_aggregation_with_partitions_per_task() { - let query = - r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#; - let plan = sql_to_explain(query, 2).await.unwrap(); + let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + + SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*) + "#; + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] @@ -427,8 +423,13 @@ mod tests { #[tokio::test] 
async fn test_left_join() { - let query = r#"SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday" "#; - let plan = sql_to_explain(query, 2).await.unwrap(); + let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + + SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday" + "#; + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ CoalescePartitionsExec @@ -444,6 +445,9 @@ mod tests { #[tokio::test] async fn test_left_join_distributed() { let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + WITH a AS ( SELECT AVG("MinTemp") as "MinTemp", @@ -465,9 +469,8 @@ mod tests { FROM a LEFT JOIN b ON a."RainTomorrow" = b."RainTomorrow" - "#; - let plan = sql_to_explain(query, 2).await.unwrap(); + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ CoalescePartitionsExec @@ -509,8 +512,13 @@ mod tests { #[tokio::test] async fn test_sort() { - let query = r#"SELECT * FROM weather ORDER BY "MinTemp" DESC "#; - let plan = sql_to_explain(query, 2).await.unwrap(); + let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + + SELECT * FROM weather ORDER BY "MinTemp" DESC + "#; + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ SortPreservingMergeExec: [MinTemp@0 DESC] @@ -526,8 +534,13 @@ mod tests { #[tokio::test] async fn test_distinct() { - let query = r#"SELECT DISTINCT "RainToday", "WindGustDir" FROM weather"#; - let plan = sql_to_explain(query, 2).await.unwrap(); + let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + + SELECT DISTINCT "RainToday", "WindGustDir" FROM weather + "#; + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ CoalescePartitionsExec @@ -550,8 +563,13 @@ mod tests { #[tokio::test] async fn test_show_columns() { - let query = r#"SHOW COLUMNS from weather"#; - let plan = sql_to_explain(query, 2).await.unwrap(); + let query = r#" + SET distributed.network_coalesce_tasks = 2; + SET distributed.network_shuffle_tasks = 2; + + SHOW COLUMNS from weather + "#; + let plan = sql_to_explain(query).await.unwrap(); assert_snapshot!(plan, @r" CoalescePartitionsExec ProjectionExec: expr=[table_catalog@0 as table_catalog, table_schema@1 as table_schema, table_name@2 as table_name, column_name@3 as column_name, data_type@5 as data_type, is_nullable@4 as is_nullable] @@ -562,36 +580,27 @@ mod tests { "); } - async fn sql_to_explain(query: &str, tasks: usize) -> Result { - sql_to_explain_with_rule( - query, - DistributedPhysicalOptimizerRule::new() - .with_network_shuffle_tasks(tasks) - .with_network_coalesce_tasks(tasks), - ) - .await - } - - async fn sql_to_explain_with_rule( - query: &str, - rule: DistributedPhysicalOptimizerRule, - ) -> Result { + async fn sql_to_explain(query: &str) -> Result { let config = SessionConfig::new() .with_target_partitions(4) + .with_option_extension(DistributedConfig::default()) .with_information_schema(true); let state = SessionStateBuilder::new() .with_default_features() - .with_physical_optimizer_rule(Arc::new(rule)) .with_config(config) + 
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) .build(); let ctx = SessionContext::new_with_state(state); register_parquet_tables(&ctx).await?; - let df = ctx.sql(query).await?; + let mut df = None; + for query in query.split(";") { + df = Some(ctx.sql(query).await?); + } - let physical_plan = df.create_physical_plan().await?; + let physical_plan = df.unwrap().create_physical_plan().await?; Ok(display_plan_ascii(physical_plan.as_ref(), false)) } } diff --git a/src/distributed_planner/mod.rs b/src/distributed_planner/mod.rs index 23f24e8..b7f1809 100644 --- a/src/distributed_planner/mod.rs +++ b/src/distributed_planner/mod.rs @@ -1,7 +1,15 @@ +mod distributed_config; mod distributed_physical_optimizer_rule; mod distributed_plan_error; mod network_boundary; -pub use distributed_physical_optimizer_rule::DistributedPhysicalOptimizerRule; +pub(crate) use distributed_config::{ + set_distributed_network_coalesce_tasks, set_distributed_network_shuffle_tasks, +}; + +pub use distributed_config::{DistributedConfig, IntoPlanDependentUsize, PlanDependentUsize}; +pub use distributed_physical_optimizer_rule::{ + DistributedPhysicalOptimizerRule, apply_network_boundaries, distribute_plan, +}; pub use distributed_plan_error::{DistributedPlanError, limit_tasks_err, non_distributable_err}; pub use network_boundary::{InputStageInfo, NetworkBoundary, NetworkBoundaryExt}; diff --git a/src/lib.rs b/src/lib.rs index 6d49c6f..a40f701 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,9 @@ pub mod test_utils; pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver}; pub use distributed_ext::DistributedExt; pub use distributed_planner::{ - DistributedPhysicalOptimizerRule, InputStageInfo, NetworkBoundary, NetworkBoundaryExt, + DistributedConfig, DistributedPhysicalOptimizerRule, InputStageInfo, IntoPlanDependentUsize, + NetworkBoundary, NetworkBoundaryExt, PlanDependentUsize, apply_network_boundaries, + distribute_plan, }; pub use execution_plans::{ DistributedExec, NetworkCoalesceExec, NetworkShuffleExec, PartitionIsolatorExec, diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs index e6d5f54..ea956da 100644 --- a/src/metrics/task_metrics_collector.rs +++ b/src/metrics/task_metrics_collector.rs @@ -153,11 +153,9 @@ mod tests { .with_default_features() .with_config(config) .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new( - DistributedPhysicalOptimizerRule::default() - .with_network_coalesce_tasks(2) - .with_network_shuffle_tasks(2), - )) + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_network_coalesce_tasks(2) + .with_distributed_network_shuffle_tasks(2) .build(); let ctx = SessionContext::from(state); diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index 4d1c06b..4c6812d 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -246,11 +246,9 @@ mod tests { if distributed { builder = builder .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new( - DistributedPhysicalOptimizerRule::default() - .with_network_coalesce_tasks(2) - .with_network_shuffle_tasks(2), - )) + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_network_coalesce_tasks(2) + .with_distributed_network_shuffle_tasks(2) } let state = builder.build(); diff --git 
a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index f00c43e..ed69050 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -14,9 +14,11 @@ mod tests { use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, execute_stream, }; + use datafusion_distributed::NetworkShuffleExec; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{DistributedExt, DistributedSessionBuilderContext}; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, NetworkShuffleExec}; + use datafusion_distributed::{ + DistributedExt, DistributedSessionBuilderContext, distribute_plan, + }; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::TryStreamExt; use prost::Message; @@ -56,7 +58,7 @@ mod tests { )?); } - let plan = DistributedPhysicalOptimizerRule::distribute_plan(plan)?; + let plan = distribute_plan(plan)?; let stream = execute_stream(plan, ctx.task_ctx())?; // It should not fail. stream.try_collect::>().await?; diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index 47134e6..30cb4f8 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -22,12 +22,12 @@ mod tests { use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, displayable, execute_stream, }; + use datafusion_distributed::NetworkShuffleExec; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ DistributedExt, DistributedSessionBuilderContext, PartitionIsolatorExec, assert_snapshot, - display_plan_ascii, + display_plan_ascii, distribute_plan, }; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, NetworkShuffleExec}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::{TryStreamExt, stream}; @@ -58,7 +58,7 @@ mod tests { "); let distributed_plan = build_plan(true)?; - let distributed_plan = DistributedPhysicalOptimizerRule::distribute_plan(distributed_plan)?; + let distributed_plan = distribute_plan(distributed_plan)?; assert_snapshot!(display_plan_ascii(distributed_plan.as_ref(), false), @r" ┌───── DistributedExec ── Tasks: t0:[p0] diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs index 1743bf5..24c41d4 100644 --- a/tests/distributed_aggregation.rs +++ b/tests/distributed_aggregation.rs @@ -1,13 +1,12 @@ #[cfg(all(feature = "integration", test))] mod tests { use datafusion::arrow::util::pretty::pretty_format_batches; - use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::{displayable, execute_stream}; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; use datafusion_distributed::{ - DefaultSessionBuilder, DistributedPhysicalOptimizerRule, assert_snapshot, - display_plan_ascii, + DefaultSessionBuilder, DistributedConfig, apply_network_boundaries, assert_snapshot, + display_plan_ascii, distribute_plan, }; use futures::TryStreamExt; use std::error::Error; @@ -24,9 +23,9 @@ mod tests { let physical_str = displayable(physical.as_ref()).indent(true).to_string(); - let physical_distributed = DistributedPhysicalOptimizerRule::default() - .with_network_shuffle_tasks(2) - .optimize(physical.clone(), &Default::default())?; + let cfg = DistributedConfig::default().with_network_shuffle_tasks(2); + 
let physical_distributed = apply_network_boundaries(physical.clone(), &cfg)?; + let physical_distributed = distribute_plan(physical_distributed)?; let physical_distributed_str = display_plan_ascii(physical_distributed.as_ref(), false); @@ -108,10 +107,11 @@ mod tests { let physical_str = displayable(physical.as_ref()).indent(true).to_string(); - let physical_distributed = DistributedPhysicalOptimizerRule::default() + let cfg = DistributedConfig::default() .with_network_shuffle_tasks(6) - .with_network_coalesce_tasks(6) - .optimize(physical.clone(), &Default::default())?; + .with_network_coalesce_tasks(6); + let physical_distributed = apply_network_boundaries(physical.clone(), &cfg)?; + let physical_distributed = distribute_plan(physical_distributed)?; let physical_distributed_str = display_plan_ascii(physical_distributed.as_ref(), false); diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index 5a2a370..bd7234b 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -14,8 +14,7 @@ mod tests { }; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, - NetworkShuffleExec, + DistributedExt, DistributedSessionBuilderContext, NetworkShuffleExec, distribute_plan, }; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; @@ -51,7 +50,7 @@ mod tests { size, )?); } - let plan = DistributedPhysicalOptimizerRule::distribute_plan(plan)?; + let plan = distribute_plan(plan)?; let stream = execute_stream(plan, ctx.task_ctx())?; let Err(err) = stream.try_collect::>().await else { diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index 422e519..74a6373 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -6,8 +6,8 @@ mod tests { use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; use datafusion_distributed::{ - DefaultSessionBuilder, DistributedPhysicalOptimizerRule, NetworkShuffleExec, - assert_snapshot, display_plan_ascii, + DefaultSessionBuilder, NetworkShuffleExec, assert_snapshot, display_plan_ascii, + distribute_plan, }; use futures::TryStreamExt; use std::error::Error; @@ -34,8 +34,7 @@ mod tests { )?); } - let physical_distributed = - DistributedPhysicalOptimizerRule::distribute_plan(physical_distributed)?; + let physical_distributed = distribute_plan(physical_distributed)?; let physical_distributed_str = display_plan_ascii(physical_distributed.as_ref(), false); assert_snapshot!(physical_str, diff --git a/tests/introspection.rs b/tests/introspection.rs index ac0c590..f411536 100644 --- a/tests/introspection.rs +++ b/tests/introspection.rs @@ -2,14 +2,13 @@ mod tests { use datafusion::arrow::util::pretty::pretty_format_batches; use datafusion::execution::SessionStateBuilder; - use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::execute_stream; use datafusion::prelude::SessionConfig; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; use datafusion_distributed::{ - DefaultSessionBuilder, DistributedPhysicalOptimizerRule, - MappedDistributedSessionBuilderExt, assert_snapshot, display_plan_ascii, + DefaultSessionBuilder, DistributedConfig, MappedDistributedSessionBuilderExt, + 
apply_network_boundaries, assert_snapshot, display_plan_ascii, distribute_plan, }; use futures::TryStreamExt; use std::error::Error; @@ -28,10 +27,10 @@ mod tests { let df = ctx.sql(r#"SHOW COLUMNS from weather"#).await?; let physical = df.create_physical_plan().await?; - let physical_distributed = DistributedPhysicalOptimizerRule::default() + let cfg = DistributedConfig::default() .with_network_shuffle_tasks(2) - .with_network_coalesce_tasks(2) - .optimize(physical.clone(), &Default::default())?; + .with_network_coalesce_tasks(2); + let physical_distributed = distribute_plan(apply_network_boundaries(physical, &cfg)?)?; let physical_distributed_str = display_plan_ascii(physical_distributed.as_ref(), false); diff --git a/tests/stateful_execution_plan.rs b/tests/stateful_execution_plan.rs index cb4d0e4..a23473f 100644 --- a/tests/stateful_execution_plan.rs +++ b/tests/stateful_execution_plan.rs @@ -23,12 +23,12 @@ mod tests { use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, execute_stream, }; + use datafusion_distributed::NetworkShuffleExec; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ DistributedExt, DistributedSessionBuilderContext, PartitionIsolatorExec, assert_snapshot, - display_plan_ascii, + display_plan_ascii, distribute_plan, }; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, NetworkShuffleExec}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::TryStreamExt; @@ -62,7 +62,7 @@ mod tests { let (ctx, _guard) = start_localhost_context(3, build_state).await; let distributed_plan = build_plan()?; - let distributed_plan = DistributedPhysicalOptimizerRule::distribute_plan(distributed_plan)?; + let distributed_plan = distribute_plan(distributed_plan)?; assert_snapshot!(display_plan_ascii(distributed_plan.as_ref(), false), @r" ┌───── DistributedExec ── Tasks: t0:[p0] diff --git a/tests/tpch_validation_test.rs b/tests/tpch_validation_test.rs index e0df1c1..9167dac 100644 --- a/tests/tpch_validation_test.rs +++ b/tests/tpch_validation_test.rs @@ -7,8 +7,8 @@ mod tests { use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::tpch; use datafusion_distributed::{ - DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, assert_snapshot, - display_plan_ascii, explain_analyze, + DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, + assert_snapshot, display_plan_ascii, explain_analyze, }; use futures::TryStreamExt; use std::error::Error; @@ -2197,13 +2197,12 @@ mod tests { async fn build_state( ctx: DistributedSessionBuilderContext, ) -> Result { - let rule = DistributedPhysicalOptimizerRule::new() - .with_network_shuffle_tasks(SHUFFLE_TASKS) - .with_network_coalesce_tasks(COALESCE_TASKS); Ok(SessionStateBuilder::new() .with_runtime_env(ctx.runtime_env) .with_default_features() - .with_physical_optimizer_rule(Arc::new(rule)) + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_network_coalesce_tasks(COALESCE_TASKS) + .with_distributed_network_shuffle_tasks(SHUFFLE_TASKS) .build()) } diff --git a/tests/udfs.rs b/tests/udfs.rs index 1e67d05..9b88db4 100644 --- a/tests/udfs.rs +++ b/tests/udfs.rs @@ -10,14 +10,13 @@ mod tests { }; use datafusion::physical_expr::expressions::lit; use datafusion::physical_expr::{Partitioning, ScalarFunctionExpr}; - use 
datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::{ExecutionPlan, execute_stream}; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, assert_snapshot, - display_plan_ascii, + DistributedConfig, DistributedSessionBuilderContext, apply_network_boundaries, + assert_snapshot, display_plan_ascii, distribute_plan, }; use futures::TryStreamExt; use std::any::Any; @@ -59,10 +58,12 @@ mod tests { let node = wrap(wrap(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))); - let physical_distributed = DistributedPhysicalOptimizerRule::default() + let cfg = DistributedConfig::default() .with_network_shuffle_tasks(2) - .with_network_coalesce_tasks(2) - .optimize(node, &Default::default())?; + .with_network_coalesce_tasks(2); + let node = apply_network_boundaries(node, &cfg)?; + + let physical_distributed = distribute_plan(node)?; let physical_distributed_str = display_plan_ascii(physical_distributed.as_ref(), false);