diff --git a/benchmarks/cdk/bin/worker.rs b/benchmarks/cdk/bin/worker.rs index 45bdfbc4..49394dca 100644 --- a/benchmarks/cdk/bin/worker.rs +++ b/benchmarks/cdk/bin/worker.rs @@ -38,6 +38,10 @@ struct Cmd { /// The bucket name. #[structopt(long, default_value = "datafusion-distributed-benchmarks")] bucket: String, + + // Turns broadcast joins on. + #[structopt(long)] + broadcast_joins: bool, } #[tokio::main] @@ -67,12 +71,15 @@ async fn main() -> Result<(), Box> { let runtime_env = Arc::new(RuntimeEnv::default()); runtime_env.register_object_store(&s3_url, s3); - let state = SessionStateBuilder::new() + let mut state = SessionStateBuilder::new() .with_default_features() .with_runtime_env(Arc::clone(&runtime_env)) .with_distributed_worker_resolver(Ec2WorkerResolver::new()) .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) .build(); + if cmd.broadcast_joins { + state = state.with_distributed_broadcast_joins_enabled(true)?; + } let ctx = SessionContext::from(state); let worker = Worker::default().with_runtime_env(runtime_env); diff --git a/benchmarks/src/run.rs b/benchmarks/src/run.rs index cdbff0a3..04655ada 100644 --- a/benchmarks/src/run.rs +++ b/benchmarks/src/run.rs @@ -93,6 +93,10 @@ pub struct RunOpt { #[structopt(long)] children_isolator_unions: bool, + /// Turns on broadcast joins. + #[structopt(long)] + enable_broadcast_joins: bool, + /// Collects metrics across network boundaries #[structopt(long)] collect_metrics: bool, @@ -210,6 +214,7 @@ impl RunOpt { self.cardinality_task_sf.unwrap_or(1.0), )? .with_distributed_children_isolator_unions(self.children_isolator_unions)? + .with_distributed_broadcast_joins_enabled(self.enable_broadcast_joins)? .with_distributed_metrics_collection(self.collect_metrics)? .build(); let ctx = SessionContext::new_with_state(state); diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index 97bde461..47401527 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -449,6 +449,23 @@ pub trait DistributedExt: Sized { &mut self, enabled: bool, ) -> Result<(), DataFusionError>; + + /// Enables broadcast joins for CollectLeft hash joins. When enabled, the build side of + /// a CollectLeft join is broadcast to all consumer tasks instead of being coalesced + /// into a single partition. + /// + /// Note: This option is disabled by default until the implementation is smarter about when to + /// broadcast. + fn with_distributed_broadcast_joins_enabled( + self, + enabled: bool, + ) -> Result; + + /// Same as [DistributedExt::with_distributed_broadcast_joins_enabled] but with an in-place mutation. + fn set_distributed_broadcast_joins_enabled( + &mut self, + enabled: bool, + ) -> Result<(), DataFusionError>; } impl DistributedExt for SessionConfig { @@ -529,6 +546,15 @@ impl DistributedExt for SessionConfig { Ok(()) } + fn set_distributed_broadcast_joins_enabled( + &mut self, + enabled: bool, + ) -> Result<(), DataFusionError> { + let d_cfg = DistributedConfig::from_config_options_mut(self.options_mut())?; + d_cfg.broadcast_joins_enabled = enabled; + Ok(()) + } + delegate! 
{ to self { #[call(set_distributed_option_extension)] @@ -574,6 +600,10 @@ impl DistributedExt for SessionConfig { #[call(set_distributed_children_isolator_unions)] #[expr($?;Ok(self))] fn with_distributed_children_isolator_unions(mut self, enabled: bool) -> Result; + + #[call(set_distributed_broadcast_joins_enabled)] + #[expr($?;Ok(self))] + fn with_distributed_broadcast_joins_enabled(mut self, enabled: bool) -> Result; } } } @@ -635,6 +665,11 @@ impl DistributedExt for SessionStateBuilder { #[call(set_distributed_children_isolator_unions)] #[expr($?;Ok(self))] fn with_distributed_children_isolator_unions(mut self, enabled: bool) -> Result; + + fn set_distributed_broadcast_joins_enabled(&mut self, enabled: bool) -> Result<(), DataFusionError>; + #[call(set_distributed_broadcast_joins_enabled)] + #[expr($?;Ok(self))] + fn with_distributed_broadcast_joins_enabled(mut self, enabled: bool) -> Result; } } } @@ -696,6 +731,11 @@ impl DistributedExt for SessionState { #[call(set_distributed_children_isolator_unions)] #[expr($?;Ok(self))] fn with_distributed_children_isolator_unions(mut self, enabled: bool) -> Result; + + fn set_distributed_broadcast_joins_enabled(&mut self, enabled: bool) -> Result<(), DataFusionError>; + #[call(set_distributed_broadcast_joins_enabled)] + #[expr($?;Ok(self))] + fn with_distributed_broadcast_joins_enabled(mut self, enabled: bool) -> Result; } } } @@ -757,6 +797,11 @@ impl DistributedExt for SessionContext { #[call(set_distributed_children_isolator_unions)] #[expr($?;Ok(self))] fn with_distributed_children_isolator_unions(self, enabled: bool) -> Result; + + fn set_distributed_broadcast_joins_enabled(&mut self, enabled: bool) -> Result<(), DataFusionError>; + #[call(set_distributed_broadcast_joins_enabled)] + #[expr($?;Ok(self))] + fn with_distributed_broadcast_joins_enabled(self, enabled: bool) -> Result; } } } diff --git a/src/distributed_planner/distributed_config.rs b/src/distributed_planner/distributed_config.rs index 5eaea74a..b070f755 100644 --- a/src/distributed_planner/distributed_config.rs +++ b/src/distributed_planner/distributed_config.rs @@ -39,6 +39,12 @@ extensions_options! { /// Propagate collected metrics from all nodes in the plan across network boundaries /// so that they can be reconstructed on the head node of the plan. pub collect_metrics: bool, default = true + /// Enable broadcast joins for CollectLeft hash joins. When enabled, the build side of + /// a CollectLeft join is broadcast to all consumer tasks. + /// TODO: This option exists temporarily until we become smarter about when to actually + /// use broadcasting like checking build side size. + /// For now, broadcasting all CollectLeft joins is not always beneficial. + pub broadcast_joins_enabled: bool, default = false /// Collection of [TaskEstimator]s that will be applied to leaf nodes in order to /// estimate how many tasks should be spawned for the [Stage] containing the leaf node. 
pub(crate) __private_task_estimator: CombinedTaskEstimator, default = CombinedTaskEstimator::default() diff --git a/src/distributed_planner/distributed_physical_optimizer_rule.rs b/src/distributed_planner/distributed_physical_optimizer_rule.rs index c3e4e79d..906a1a3a 100644 --- a/src/distributed_planner/distributed_physical_optimizer_rule.rs +++ b/src/distributed_planner/distributed_physical_optimizer_rule.rs @@ -3,9 +3,9 @@ use crate::distributed_planner::plan_annotator::{ AnnotatedPlan, RequiredNetworkBoundary, annotate_plan, }; use crate::{ - DistributedConfig, DistributedExec, NetworkCoalesceExec, NetworkShuffleExec, TaskEstimator, + DistributedConfig, DistributedExec, NetworkBroadcastExec, NetworkCoalesceExec, + NetworkShuffleExec, TaskEstimator, }; -use datafusion::common::internal_err; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::config::ConfigOptions; use datafusion::error::DataFusionError; @@ -80,7 +80,7 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { /// - The leaf nodes are scaled up in parallelism based on the number of distributed tasks in /// which they are going to run. This is configurable by the user via the [TaskEstimator] trait. /// - The appropriate network boundaries are placed in the plan depending on how it was annotated, -/// so new nodes like [NetworkCoalesceExec] and [NetworkShuffleExec] will be present. +/// so new nodes like [NetworkBroadcastExec], [NetworkCoalesceExec] and [NetworkShuffleExec] will be present. fn distribute_plan( annotated_plan: AnnotatedPlan, cfg: &ConfigOptions, @@ -88,19 +88,18 @@ fn distribute_plan( stage_id: &mut usize, ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> { let d_cfg = DistributedConfig::from_config_options(cfg)?; - let children = annotated_plan.children; - // This is a leaf node, so we need to scale it up with the final task count. + let parent_task_count = annotated_plan.task_count.as_usize(); + if children.is_empty() { let scaled_up = d_cfg.__private_task_estimator.scale_up_leaf_node( &annotated_plan.plan, - annotated_plan.task_count.as_usize(), + parent_task_count, cfg, ); return Ok(scaled_up.unwrap_or(annotated_plan.plan)); } - let parent_task_count = annotated_plan.task_count.as_usize(); let max_child_task_count = children.iter().map(|v| v.task_count.as_usize()).max(); let new_children = children @@ -108,50 +107,64 @@ .map(|child| distribute_plan(child, cfg, query_id, stage_id)) .collect::<Result<Vec<_>, _>>()?; - // It does not need a NetworkBoundary, so just keep recursing. - let Some(nb_req) = annotated_plan.required_network_boundary else { + // We see significant speed-ups when we introduce network boundaries between nested CollectLeft + // HashJoinExecs. This doesn't make logical sense as they are the same plans, only differing + // via additional network boundaries which solely cause overhead. + // + // Hypothesis: The network boundary operators use a spawn_select_all which just buffers streams + // in memory, leading to a faster execution even with extra network hops. + if annotated_plan.required_network_boundary.is_some() + && parent_task_count == 1 + && max_child_task_count == Some(1) + { return annotated_plan.plan.with_new_children(new_children); - }; - - // It would need a network boundary, but on both sides of the boundary there is just 1 task, - // so we are fine with not introducing any network boundary.
- if parent_task_count == 1 && max_child_task_count == Some(1) { - return annotated_plan.plan.with_new_children(new_children); - } - - // If the current node has a RepartitionExec below, it needs a shuffle, so put one - // NetworkShuffleExec boundary in between the RepartitionExec and the current node. - if nb_req == RequiredNetworkBoundary::Shuffle { - let new_child = Arc::new(NetworkShuffleExec::try_new( - require_one_child(new_children)?, - query_id, - *stage_id, - parent_task_count, - max_child_task_count.unwrap_or(1), - )?); - stage_id.add_assign(1); - return annotated_plan.plan.with_new_children(vec![new_child]); } - // If this is a CoalescePartitionsExec or a SortMergePreservingExec, it means that the original - // plan is trying to merge all partitions into one. We need to go one step ahead and also merge - // all distributed tasks into one. - if nb_req == RequiredNetworkBoundary::Coalesce { - let new_child = Arc::new(NetworkCoalesceExec::try_new( - require_one_child(new_children)?, - query_id, - *stage_id, - parent_task_count, - max_child_task_count.unwrap_or(1), - )?); - stage_id.add_assign(1); - return annotated_plan.plan.with_new_children(vec![new_child]); + match annotated_plan.required_network_boundary { + // No network boundary needed, just recurse on children. + None => annotated_plan.plan.with_new_children(new_children), + // If the current node has a RepartitionExec below, it needs a shuffle, so put one + // NetworkShuffleExec boundary in between the RepartitionExec and the current node. + Some(RequiredNetworkBoundary::Shuffle) => { + let new_child = Arc::new(NetworkShuffleExec::try_new( + require_one_child(new_children)?, + query_id, + *stage_id, + parent_task_count, + max_child_task_count.unwrap_or(1), + )?); + stage_id.add_assign(1); + annotated_plan.plan.with_new_children(vec![new_child]) + } + // If this is a CoalescePartitionsExec or a SortMergePreservingExec, it means that the original + // plan is trying to merge all partitions into one. We need to go one step ahead and also merge + // all distributed tasks into one. + Some(RequiredNetworkBoundary::Coalesce) => { + let new_child = Arc::new(NetworkCoalesceExec::try_new( + require_one_child(new_children)?, + query_id, + *stage_id, + parent_task_count, + max_child_task_count.unwrap_or(1), + )?); + stage_id.add_assign(1); + annotated_plan.plan.with_new_children(vec![new_child]) + } + // Broadcast boundary is placed on the build child of a CollectLeft HashJoinExec, it means + // that the build side (this node) is trying to broadcast to all consumers. We need to + // insert a BroadcastExec and NetworkBroadcastExec. + Some(RequiredNetworkBoundary::Broadcast) => { + let new_child = NetworkBroadcastExec::with_inner_broadcast( + require_one_child(new_children)?, + query_id, + *stage_id, + parent_task_count, + max_child_task_count.unwrap_or(1), + )?; + stage_id.add_assign(1); + annotated_plan.plan.with_new_children(vec![new_child]) + } } - - internal_err!( - "Unreachable code reached in distribute_plan. 
Could not determine how to place a network boundary below {}", - annotated_plan.plan.name() - ) } /// Rearranges the [CoalesceBatchesExec] nodes in the plan so that they are placed right below @@ -197,7 +210,7 @@ fn push_down_batch_coalescing( mod tests { use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver; use crate::test_utils::parquet::register_parquet_tables; - use crate::{DistributedExt, DistributedPhysicalOptimizerRule}; + use crate::{DistributedConfig, DistributedExt, DistributedPhysicalOptimizerRule}; use crate::{assert_snapshot, display_plan_ascii}; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -760,6 +773,88 @@ mod tests { "); } + #[tokio::test] + async fn test_broadcast_creates_network_broadcast_exec() { + let query = r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a INNER JOIN weather b + ON a."RainToday" = b."RainToday" + "#; + let plan = sql_to_explain_with_broadcast(query, 3, true).await; + assert_snapshot!(plan, @r" + ┌───── DistributedExec ── Tasks: t0:[p0] + │ CoalescePartitionsExec + │ [Stage 2] => NetworkCoalesceExec: output_partitions=3, input_tasks=3 + └────────────────────────────────────────────────── + ┌───── Stage 2 ── Tasks: t0:[p0] t1:[p1] t2:[p2] + │ CoalesceBatchesExec: target_batch_size=8192 + │ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2] + │ CoalescePartitionsExec + │ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3 + │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ] + └────────────────────────────────────────────────── + ┌───── Stage 1 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8] + │ BroadcastExec: input_partitions=1, consumer_tasks=3, output_partitions=3 + │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + └────────────────────────────────────────────────── + ") + } + + #[tokio::test] + async fn test_broadcast_downgrades_to_coalesce_single_consumer() { + // When broadcast_joins_enabled is true but there's only single consumer, + // broadcast should downgrade to coalesce + let query = r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a INNER JOIN weather b + ON a."RainToday" = b."RainToday" + "#; + let plan = sql_to_explain_with_broadcast(query, 0, true).await; + assert_snapshot!(plan, @r" + CoalesceBatchesExec: target_batch_size=8192 + HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2] + CoalescePartitionsExec + DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, 
predicate=DynamicFilter [ empty ] + ") + } + + #[tokio::test] + async fn test_broadcast_nested_joins() { + let query = r#" + SELECT a."MinTemp", b."MaxTemp", c."Rainfall" + FROM weather a + INNER JOIN weather b ON a."RainToday" = b."RainToday" + INNER JOIN weather c ON b."RainToday" = c."RainToday" + "#; + let plan = sql_to_explain_with_broadcast(query, 3, true).await; + assert_snapshot!(plan, @r" + ┌───── DistributedExec ── Tasks: t0:[p0] + │ CoalescePartitionsExec + │ CoalesceBatchesExec: target_batch_size=8192 + │ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@2, RainToday@1)], projection=[MinTemp@0, MaxTemp@1, Rainfall@3] + │ CoalescePartitionsExec + │ [Stage 2] => NetworkCoalesceExec: output_partitions=3, input_tasks=3 + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[Rainfall, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ] + └────────────────────────────────────────────────── + ┌───── Stage 2 ── Tasks: t0:[p0] t1:[p1] t2:[p2] + │ CoalesceBatchesExec: target_batch_size=8192 + │ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2, RainToday@3] + │ CoalescePartitionsExec + │ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3 + │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ] + └────────────────────────────────────────────────── + ┌───── Stage 1 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8] + │ BroadcastExec: input_partitions=1, consumer_tasks=3, output_partitions=3 + │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + └────────────────────────────────────────────────── + ") + } + async fn sql_to_explain( query: &str, f: impl FnOnce(SessionStateBuilder) -> SessionStateBuilder, @@ -790,4 +885,42 @@ mod tests { let physical_plan = df.create_physical_plan().await.unwrap(); display_plan_ascii(physical_plan.as_ref(), false) } + + async fn sql_to_explain_with_broadcast( + query: &str, + num_workers: usize, + broadcast_enabled: bool, + ) -> String { + let mut config = SessionConfig::new() + .with_target_partitions(4) + .with_information_schema(true); + + let d_cfg = DistributedConfig { + broadcast_joins_enabled: broadcast_enabled, + ..Default::default() + }; + config.set_distributed_option_extension(d_cfg).unwrap(); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_config(config) + .with_distributed_worker_resolver(InMemoryWorkerResolver::new(num_workers)) + .build(); + + let ctx = SessionContext::new_with_state(state); + let mut queries = query.split(";").collect_vec(); + let last_query = queries.pop().unwrap(); + + for query in queries { + ctx.sql(query).await.unwrap(); + } + + register_parquet_tables(&ctx).await.unwrap(); + + let df = ctx.sql(last_query).await.unwrap(); + + let physical_plan = 
df.create_physical_plan().await.unwrap(); + display_plan_ascii(physical_plan.as_ref(), false) + } } diff --git a/src/distributed_planner/network_boundary.rs b/src/distributed_planner/network_boundary.rs index 25359724..fcb0b676 100644 --- a/src/distributed_planner/network_boundary.rs +++ b/src/distributed_planner/network_boundary.rs @@ -1,4 +1,4 @@ -use crate::{NetworkCoalesceExec, NetworkShuffleExec, Stage}; +use crate::{NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec, Stage}; use datafusion::physical_plan::ExecutionPlan; use std::sync::Arc; @@ -35,6 +35,8 @@ impl NetworkBoundaryExt for dyn ExecutionPlan { Some(node) } else if let Some(node) = self.as_any().downcast_ref::() { Some(node) + } else if let Some(node) = self.as_any().downcast_ref::() { + Some(node) } else { None } diff --git a/src/distributed_planner/plan_annotator.rs b/src/distributed_planner/plan_annotator.rs index 87b93671..446f4b95 100644 --- a/src/distributed_planner/plan_annotator.rs +++ b/src/distributed_planner/plan_annotator.rs @@ -1,3 +1,4 @@ +use crate::TaskCountAnnotation::{Desired, Maximum}; use crate::execution_plans::ChildrenIsolatorUnionExec; use crate::{DistributedConfig, TaskCountAnnotation, TaskEstimator}; use datafusion::common::{DataFusionError, plan_datafusion_err}; @@ -19,6 +20,7 @@ use std::sync::Arc; pub(super) enum RequiredNetworkBoundary { Shuffle, Coalesce, + Broadcast, } /// Wraps an [ExecutionPlan] and annotates it with information about how many distributed tasks @@ -63,88 +65,220 @@ impl Debug for AnnotatedPlan { } } -/// Annotates recursively an [ExecutionPlan] and its children with information about how many -/// distributed tasks it should run on, and whether it needs a network boundary below it or not. +/// Annotates an [ExecutionPlan] in three passes. /// -/// This is the first step of the distribution process, where the plan structure is still left -/// untouched and the existing nodes are just annotated for future steps to perform the distribution. -/// -/// The plans are annotated in a bottom-to-top manner, starting with the leaf nodes all the way -/// to the head of the plan: +/// Here is an un-annotated, single-node plan which will be used to understand each phase's purpose: +/// ```text +/// ┌──────────────────────┐ +/// │ │ +/// │ CoalesceBatches │ +/// │ │ +/// └───────────▲──────────┘ +/// ┌───────────┴──────────┐ +/// │ HashJoin │ +/// │ (CollectLeft) │ +/// │ │ +/// └────▲────────────▲────┘ +/// │ │ +/// ┌─────────┘ └──────────┐ +/// Build Side Probe Side +/// ┌───────────┴──────────┐ ┌───────────┴──────────┐ +/// │ │ │ │ +/// │ CoalescePartitions │ │ CoalescePartitions │ +/// │ │ │ │ +/// └───────────▲──────────┘ └───────────▲──────────┘ +/// ┌───────────┴──────────┐ ┌───────────┴──────────┐ +/// │ │ │ │ +/// │ DataSource │ │ Projection │ +/// │ │ │ │ +/// └──────────────────────┘ └───────────▲──────────┘ +/// ┌───────────┴──────────┐ +/// │ Aggregation │ +/// │ (Final) │ +/// │ │ +/// └───────────▲──────────┘ +/// ┌───────────┴──────────┐ +/// │ │ +/// │ Repartition │ +/// │ │ +/// └───────────▲──────────┘ +/// ┌───────────┴──────────┐ +/// │ Aggregation │ +/// │ (Partial) │ +/// │ │ +/// └───────────▲──────────┘ +/// ┌───────────┴──────────┐ +/// │ │ +/// │ DataSource │ +/// │ │ +/// └──────────────────────┘ +/// ``` /// -/// 1. Leaf nodes have the opportunity to provide an estimation of how many distributed tasks should -/// be used for the whole stage that will execute them. +/// # Pass 1: bottom-to-top Task Estimation and Mark Network Boundaries /// -/// 2. 
If a stage contains multiple leaf nodes, and all provide a task count estimation, the -/// biggest is taken. +/// This pass is a bottom-to-top pass that sets each node's task_count that depends on its children's task_count +/// and required_network_boundary in the [AnnotatedPlan]. This stems from the [DataSourceExec] nodes and sets +/// task_counts based on the estimated amount of data and cardinality. /// -/// 3. When traversing the plan in a bottom-to-top fashion, this function looks for nodes that -/// either increase or reduce cardinality: -/// - If there's a node that increases cardinality, the next stage will spawn more tasks than -/// the current one. -/// - If there's a node that reduces cardinality, the next stage will spawn fewer tasks than the -/// current one. +/// Regarding task_count this marks: +/// 1. DataSourceExec -> Estimates task_count via the [TaskEstimator]. +/// 2. Non-[NetworkBoundary] nodes -> inherits the max task_count from its children. +/// 3. [NetworkBoundary] nodes: +/// - [NetworkBoundary::Coalesce] -> Maximum(1): trying to coalesce partitions into 1. +/// - [NetworkBoundary::Shuffle] -> Desired(N): calculated based on its child and if +/// cardinality is increased or decreased. +/// - [NetworkBoundary::Broadcast] -> Desired(1): this is a placeholder value because +/// broadcst boundaries depend on their parent's task_count (the amount of consumers). Unlike other +/// boundaries, which depend on their childrens' task_counts, the parent's task count isn't +/// known during this bottom-to-top traversal. This will be correctly in pass 3. /// -/// 4. At a certain point, the function will reach a node that needs a network boundary below; in -/// that case, the node is annotated with a [RequiredNetworkBoundary] value. At this point, all -/// the nodes below must reach a consensus about the final task count for the stage below the -/// network boundary. +/// Regarding required_network_boundary this marks: +/// 1. [RepartitionExec] with [Partitioning::Hash] -> [RequiredNetworkBoundary::Shuffle] +/// 2. [CoalescePartitionsExec] or [SortPreservingMergeExec] -> [RequiredNetworkBoundary::Coalesce] +/// 3. The build (left) child of a [HashJoinExec] with [PartitionMode::CollectLeft] -> [RequiredNetworkBoundary::Broadcast] /// -/// 5. This process is repeated recursively until all nodes are annotated. +/// The example plan after this pass would look like: +/// ```text +/// ┌──────────────────────┐ +/// │ │ required_network_boundary: None +/// │ CoalesceBatches │ task_count: Maximum(1) +/// │ │ Explanation: task_count inherited from child. +/// └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ HashJoin │ required_network_boundary: None +/// │ (CollectLeft) │ task_count: Maximum(1) +/// │ │ Explanation: With two children (X, Y) a node chooses +/// └────▲────────────▲────┘ Maximum(Y) if X: Desired(N) and Y: Maximum(M). +/// │ │ +/// ┌─────────┘ └──────────┐ +/// Build Side Probe Side +/// │ │ +/// required_network_boundary: Some(Broadcast) ┌───────────┴──────────┐ ┌───────────┴──────────┐ +/// task_count: Desired(1) │ │ │ │ required_network_boundary: Some(Coalesce) +/// Explanation: Is a [NetworkBoundary::Broadcast] because its the build │ CoalescePartitions │ │ CoalescePartitions │ task_count: Maximum(1) +/// child of a [PartitionMode::CollectLeft] [HashJoinExec]. 
task_count is a │ │ │ │ Explanation: Is a [NetworkBoundary::Coalesce] because it is a +/// placeholder since [NetworkBoundary::Broadcast] depends on its parent's └───────────▲──────────┘ └───────────▲──────────┘ [CoalescePartitionsExec] and not the build child. +/// task_count but is not known here (will be set in next pass). │ │ +/// ┌───────────┴──────────┐ ┌───────────┴──────────┐ +/// required_network_boundary: None │ │ │ │ required_network_boundary: None +/// task_count: Desired(2) │ DataSource │ │ Projection │ task_count: Desired(2) +/// Explanation: task_count calculated by the [TaskEstimator]. │ │ │ │ Explanation: task_count inherited from child. +/// └──────────────────────┘ └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ required_network_boundary: Some(Shuffle) +/// │ Aggregation │ task_count: Desired(2) +/// │ (Final) │ Explanation: Is a [NetworkBoundary::Shuffle] because its child +/// │ │ is a [RepartitionExec]. task_count calculated based on +/// └───────────▲──────────┘ cardinality, which is reduced in stage below. +/// │ +/// ┌───────────┴──────────┐ +/// │ │ required_network_boundary: None +/// │ Repartition │ task_count: Desired(4) +/// │ │ Explanation: task_count inherited from child. +/// └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ Aggregation │ required_network_boundary: None +/// │ (Partial) │ task_count: Desired(4) +/// │ │ Explanation: task_count inherited from child. +/// └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ │ required_network_boundary: None +/// │ DataSource │ task_count: Desired(4) +/// │ │ Explanation: task_count calculated by the [TaskEstimator]. +/// └──────────────────────┘ +/// ``` /// -/// ## Example: +/// # Pass 2: Set Operators' Task Count Top-to-bottom /// -/// Following the process above, an annotated plan will look like this: +/// This pass is a top-to-bottom pass that sets each node's task count in the [AnnotatedPlan] that +/// depends on their parent's task_count. This pass is used for nodes marked [NetworkBoundary::Broadcast]. +/// As a result this marks: +/// 1. [NetworkBoundary::Broadcast] -> Parent's task_count if task_count > 1. +/// - This pass will downgrade [NetworkBoundary::Broadcast] -> [NetworkBoundary::Coalesce] +/// is task_count == 1 since their is no benefit to broadcasting and caching adds slight +/// overhead. +/// 2. All other nodes -> Unchanged since their task_count does not depend on their parent's. 
/// +/// The example plan after this phase would look like: /// ```text -/// ┌────────────────────┐ task_count: Maximum(1) (because we try to coalesce all partitions into 1) -/// │ CoalescePartitions │ network_boundary: Some(Coalesce) -/// └──────────▲─────────┘ -/// │ -/// ┌──────────┴─────────┐ task_count: Desired(3) (inherited from the child) -/// │ Projection │ network_boundary: None -/// └──────────▲─────────┘ -/// │ -/// ┌──────────┴─────────┐ task_count: Desired(3) (as this node requires a network boundary below, -/// │ Aggregation │ and the stage below reduces the cardinality of the data because of the -/// │ (final) │ partial aggregation, we can choose a smaller amount of tasks) -/// └──────────▲─────────┘ network_boundary: Some(Shuffle) (because the child is a repartition) -/// │ -/// ┌──────────┴─────────┐ task_count: Desired(4) (inherited from the child) -/// │ Repartition │ network_boundary: None -/// └──────────▲─────────┘ -/// │ -/// ┌──────────┴─────────┐ task_count: Desired(4) (inherited from the child) -/// │ Aggregation │ network_boundary: None -/// │ (partial) │ -/// └──────────▲─────────┘ -/// │ -/// ┌──────────┴─────────┐ task_count: Desired(4) (this was set by a TaskEstimator implementation) -/// │ DataSourceExec │ network_boundary: None -/// └────────────────────┘ -/// ``` -/// +/// ┌──────────────────────┐ +/// │ │ required_network_boundary: None +/// │ CoalesceBatches │ task_count: Maximum(1) +/// │ │ +/// └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ HashJoin │ required_network_boundary: None +/// │ (CollectLeft) │ task_count: Maximum(1) +/// │ │ +/// └────▲────────────▲────┘ +/// │ │ +/// ┌─────────┘ └──────────┐ +/// Build Side Probe Side +/// │ │ +/// required_network_boundary: Some(Broadcast) ┌───────────┴──────────┐ ┌───────────┴──────────┐ +/// task_count: Maximum(1) │ │ │ │ required_network_boundary: Some(Coalesce) +/// Explanation: Inherited from its parent since it is │ CoalescePartitions │ │ CoalescePartitions │ task_count: Maximum(1) +/// the build child of a [HashJoinExec] with │ │ │ │ +/// [PartitionMode::CollectLeft] └───────────▲──────────┘ └───────────▲──────────┘ +/// │ │ +/// ┌───────────┴──────────┐ ┌───────────┴──────────┐ +/// required_network_boundary: None │ │ │ │ required_network_boundary: None +/// task_count: Desired(2) │ DataSource │ │ Projection │ task_count: Desired(2) +/// │ │ │ │ +/// └──────────────────────┘ └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ Aggregation │ required_network_boundary: Some(Shuffle) +/// │ (Final) │ task_count: Desired(2) +/// │ │ +/// └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ │ required_network_boundary: None +/// │ Repartition │ task_count: Desired(4) +/// │ │ +/// └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ Aggregation │ required_network_boundary: None +/// │ (Partial) │ task_count: Desired(4) +/// │ │ +/// └───────────▲──────────┘ +/// │ +/// ┌───────────┴──────────┐ +/// │ │ required_network_boundary: None +/// │ DataSource │ task_count: Desired(4) +/// │ │ +/// └──────────────────────┘ /// ``` pub(super) fn annotate_plan( plan: Arc, cfg: &ConfigOptions, ) -> Result { - _annotate_plan(plan, cfg, true) + let annotated_plan = _annotate_plan_bottom_up(plan, cfg, true)?; + annotate_plan_top_down(annotated_plan) } -fn _annotate_plan( + +/// This is Phase 1 of annoatation as described above which sets initial task_counts and marks all +/// necessary [NetworkBoundary]. 
+fn _annotate_plan_bottom_up( plan: Arc<dyn ExecutionPlan>, cfg: &ConfigOptions, root: bool, ) -> Result<AnnotatedPlan, DataFusionError> { - use TaskCountAnnotation::*; let d_cfg = DistributedConfig::from_config_options(cfg)?; + let broadcast_joins_enabled = d_cfg.broadcast_joins_enabled; let estimator = &d_cfg.__private_task_estimator; let n_workers = d_cfg.__private_worker_resolver.0.get_urls()?.len().max(1); let annotated_children = plan .children() .iter() - .map(|child| _annotate_plan(Arc::clone(child), cfg, false)) + .map(|child| _annotate_plan_bottom_up(Arc::clone(child), cfg, false)) .collect::<Result<Vec<_>, _>>()?; if plan.children().is_empty() { @@ -184,10 +318,10 @@ fn _annotate_plan( task_count = Desired(count); } else if let Some(node) = plan.as_any().downcast_ref::<HashJoinExec>() && node.mode == PartitionMode::CollectLeft + && !broadcast_joins_enabled { - // We cannot distribute CollectLeft HashJoinExec nodes yet. Once - // https://github.com/datafusion-contrib/datafusion-distributed/pull/229 lands, - // we can remove this check. + // Only distribute CollectLeft HashJoins after we broadcast more intelligently or when it + // is explicitly enabled. task_count = Maximum(1); } else { // The task count for this plan is decided by the biggest task count from the children; unless @@ -206,13 +340,28 @@ fn _annotate_plan( task_count = task_count.limit(n_workers); - // The plan does not need a NetworkBoundary, so just take the biggest task count from - // the children and annotate the plan with that. - let mut annotated_plan = AnnotatedPlan { - required_network_boundary: required_network_boundary_below(plan.as_ref()), - children: annotated_children, - task_count, - plan, + // Check if this plan needs a network boundary below it. + let boundary = required_network_boundary_below(plan.as_ref(), broadcast_joins_enabled); + + // For Broadcast, mark the direct build child. + let mut annotated_plan = if boundary == Some(RequiredNetworkBoundary::Broadcast) { + let mut children = annotated_children; + if let Some(build_child) = children.first_mut() { + build_child.required_network_boundary = Some(RequiredNetworkBoundary::Broadcast); + } + AnnotatedPlan { + required_network_boundary: None, + children, + task_count, + plan, + } + } else { + AnnotatedPlan { + required_network_boundary: boundary, + children: annotated_children, + task_count, + plan, + } }; // The plan needs a NetworkBoundary. At this point we have all the info we need for choosing @@ -224,10 +373,12 @@ fn _annotate_plan( d_cfg: &DistributedConfig, ) -> Result<(), DataFusionError> { plan.task_count = task_count.clone(); + // For Shuffle/Coalesce boundaries, we've set task_count above but don't propagate + // further (children are in a different stage). if plan.required_network_boundary.is_some() { - // nothing to propagate here, all the nodes below the network boundary were already - // assigned a task count, we do not want to overwrite it. - } else if d_cfg.children_isolator_unions && plan.plan.as_any().is::<ChildrenIsolatorUnionExec>() { + return Ok(()); + } + if d_cfg.children_isolator_unions && plan.plan.as_any().is::<ChildrenIsolatorUnionExec>() { // Propagating through ChildrenIsolatorUnionExec is not that easy, each child will // be executed in its own task, and therefore, they will act as if they were in executing // in a non-distributed context. The ChildrenIsolatorUnionExec itself will make sure to @@ -315,14 +466,17 @@ fn _annotate_plan( /// Returns if the [ExecutionPlan] requires a network boundary below it, and if it does, the kind /// of network boundary ([RequiredNetworkBoundary]).
-fn required_network_boundary_below(parent: &dyn ExecutionPlan) -> Option { +fn required_network_boundary_below( + parent: &dyn ExecutionPlan, + broadcast_joins_enabled: bool, +) -> Option { let children = parent.children(); let first_child = children.first()?; - if let Some(r_exec) = first_child.as_any().downcast_ref::() { - if matches!(r_exec.partitioning(), Partitioning::Hash(_, _)) { - return Some(RequiredNetworkBoundary::Shuffle); - } + if let Some(r_exec) = first_child.as_any().downcast_ref::() + && matches!(r_exec.partitioning(), Partitioning::Hash(_, _)) + { + return Some(RequiredNetworkBoundary::Shuffle); } if parent.as_any().is::() || parent.as_any().is::() @@ -334,16 +488,51 @@ fn required_network_boundary_below(parent: &dyn ExecutionPlan) -> Option() + && hash_join.partition_mode() == &PartitionMode::CollectLeft + { + return Some(RequiredNetworkBoundary::Broadcast); + } None } +fn annotate_plan_top_down(mut plan: AnnotatedPlan) -> Result { + // Set broadcast children's task_count to parent's task_count + // Downgrade Broadcast to Coalesce if parent's task_count <= 1 since there is no benefit from + // broadcasting to a single consumer. + let parent_task_count = plan.task_count.as_usize(); + for child in &mut plan.children { + if child.required_network_boundary == Some(RequiredNetworkBoundary::Broadcast) { + if parent_task_count > 1 { + child.task_count = plan.task_count.clone(); + } else { + child.required_network_boundary = Some(RequiredNetworkBoundary::Coalesce); + child.task_count = Maximum(1); + } + } + } + + let annotated_children = plan + .children + .into_iter() + .map(annotate_plan_top_down) + .collect::, _>>()?; + + plan.children = annotated_children; + Ok(plan) +} + #[cfg(test)] mod tests { use super::*; use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver; use crate::test_utils::parquet::register_parquet_tables; - use crate::{DistributedExt, TaskEstimation, assert_snapshot}; + use crate::{DistributedConfig, DistributedExt, TaskEstimation, assert_snapshot}; use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::filter::FilterExec; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -403,6 +592,8 @@ mod tests { ") } + // TODO: should be changed once broadcasting is done more intelligently and not behind a + // feature flag. #[tokio::test] async fn test_left_join() { let query = r#" @@ -418,6 +609,8 @@ mod tests { ") } + // TODO: should be changed once broadcasting is done more intelligently and not behind a + // feature flag. #[tokio::test] async fn test_left_join_distributed() { let query = r#" @@ -469,6 +662,8 @@ mod tests { ") } + // TODO: should be changed once broadcasting is done more intelligently and not behind a + // feature flag. 
#[tokio::test] async fn test_inner_join() { let query = r#" @@ -630,6 +825,81 @@ mod tests { ") } + #[tokio::test] + async fn test_broadcast_join_annotation() { + let query = r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a INNER JOIN weather b + ON a."RainToday" = b."RainToday" + "#; + let annotated = sql_to_annotated_broadcast(query, 4, 4, true).await; + assert_snapshot!(annotated, @r" + CoalesceBatchesExec: task_count=Desired(3) + HashJoinExec: task_count=Desired(3) + CoalescePartitionsExec: task_count=Desired(3), required_network_boundary=Broadcast + DataSourceExec: task_count=Desired(3) + DataSourceExec: task_count=Desired(3) + ") + } + + #[tokio::test] + async fn test_broadcast_downgrade_single_consumer() { + let query = r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a INNER JOIN weather b + ON a."RainToday" = b."RainToday" + "#; + let annotated = sql_to_annotated_broadcast(query, 1, 1, true).await; + // With single consumer, broadcast should downgrade to coalesce + assert_snapshot!(annotated, @r" + CoalesceBatchesExec: task_count=Desired(1) + HashJoinExec: task_count=Desired(1) + DataSourceExec: task_count=Maximum(1), required_network_boundary=Coalesce + DataSourceExec: task_count=Desired(1) + ") + } + + #[tokio::test] + async fn test_broadcast_disabled_default() { + let query = r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a INNER JOIN weather b + ON a."RainToday" = b."RainToday" + "#; + let annotated = sql_to_annotated_broadcast(query, 4, 4, false).await; + // With broadcast disabled, no Broadcast annotation should appear + assert!(!annotated.contains("Broadcast")); + assert_snapshot!(annotated, @r" + CoalesceBatchesExec: task_count=Maximum(1) + HashJoinExec: task_count=Maximum(1) + CoalescePartitionsExec: task_count=Maximum(1) + DataSourceExec: task_count=Maximum(1) + DataSourceExec: task_count=Maximum(1) + ") + } + + #[tokio::test] + async fn test_broadcast_multi_join_chain() { + let query = r#" + SELECT a."MinTemp", b."MaxTemp", c."Rainfall" + FROM weather a + INNER JOIN weather b ON a."RainToday" = b."RainToday" + INNER JOIN weather c ON b."RainToday" = c."RainToday" + "#; + let annotated = sql_to_annotated_broadcast(query, 4, 4, true).await; + assert_snapshot!(annotated, @r" + CoalesceBatchesExec: task_count=Maximum(1) + HashJoinExec: task_count=Maximum(1) + CoalescePartitionsExec: task_count=Maximum(1), required_network_boundary=Coalesce + CoalesceBatchesExec: task_count=Desired(3) + HashJoinExec: task_count=Desired(3) + CoalescePartitionsExec: task_count=Desired(3), required_network_boundary=Broadcast + DataSourceExec: task_count=Desired(3) + DataSourceExec: task_count=Desired(3) + DataSourceExec: task_count=Maximum(1) + ") + } + #[allow(clippy::type_complexity)] struct CallbackEstimator { f: Arc Option + Send + Sync>, @@ -673,6 +943,41 @@ mod tests { sql_to_annotated_with_options(query, move |b| b).await } + async fn sql_to_annotated_broadcast( + query: &str, + target_partitions: usize, + num_workers: usize, + broadcast_enabled: bool, + ) -> String { + let mut config = SessionConfig::new() + .with_target_partitions(target_partitions) + .with_information_schema(true); + + let d_cfg = DistributedConfig { + broadcast_joins_enabled: broadcast_enabled, + ..Default::default() + }; + config.set_distributed_option_extension(d_cfg).unwrap(); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(config) + .with_distributed_worker_resolver(InMemoryWorkerResolver::new(num_workers)) + .build(); + + let ctx = 
SessionContext::new_with_state(state); + register_parquet_tables(&ctx).await.unwrap(); + + let df = ctx.sql(query).await.unwrap(); + + let annotated = annotate_plan( + df.create_physical_plan().await.unwrap(), + ctx.state_ref().read().config_options().as_ref(), + ) + .expect("failed to annotate plan"); + format!("{annotated:?}") + } + async fn sql_to_annotated_with_estimator( query: &str, estimator: impl Fn(&T) -> Option + Send + Sync + 'static, diff --git a/src/execution_plans/broadcast.rs b/src/execution_plans/broadcast.rs new file mode 100644 index 00000000..7077de5c --- /dev/null +++ b/src/execution_plans/broadcast.rs @@ -0,0 +1,187 @@ +use crate::common::require_one_child; +use datafusion::arrow::array::RecordBatch; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::error::Result; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, +}; +use futures::stream; +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; +use tokio::sync::OnceCell; + +/// [ExecutionPlan] that scales up partitions for network broadcasting. +/// +/// This plan takes N input partitions and exposes N*M output partitions, +/// where M is the number of consumer tasks. Each virtual partition `i` +/// returns the cached result of input partition `i % N`. +/// +/// This allows each consumer task to fetch a unique set of partition numbers, +/// the virtual partitions, while all receiving the same data via the actual partitions. +/// This structure maintains the invariant that each partition is executed exactly +/// once by the framework. +/// +/// Broadcast is used in a 1 to many context, like this: +/// ```text +/// ┌────────────────────────┐ ┌────────────────────────┐ ┌────────────────────────┐ ■ +/// │ NetworkBroadcastExec │ │ NetworkBroadcastExec │ ... │ NetworkBroadcastExec │ │ +/// │ (task 1) │ │ (task 2) │ │ (task M) │ Stage N+1 +/// └┬─┬─────┬───┬───────────┘ └───────┬─┬─────┬────┬───┘ └─────┬──────┬─────┬────┬┘ │ +/// │0│ │N-1│ │N│ │2N-1│ │(M-1)N│ │MN-1│ │ +/// └▲┘ ... └▲──┘ └▲┘ ... └──▲─┘ └───▲──┘ ... └──▲─┘ ■ +/// │ │ Populates │ │ │ │ +/// │ └────Cache Index ───┐ Cache Hit Cache Hit ┌──Cache Hit────┘ │ +/// │ N-1 │ Index 0 Index N-1 │ │ +/// └────Populates ─────┐ │ │ │ │ ┌───Cache Hit──┘ +/// Cache Index 0 │ │ │ │ │ │ +/// ┌┴┐ ... ┌┴──┐ ┌┴┐ ... ┌──┴─┐ ... ┌───┴──┐ ... ┌───┴┐ ■ +/// │0│ │N-1│ │N│ │2N-1│ │(M-1)N│ │MN-1│ │ +/// ┌┴─┴─────┴───┴──────┴─┴─────┴────┴───────────────────┴──────┴─────┴────┴┐ │ +/// │ BroadcastExec │ │ +/// │ ┌───────────────────────────┐ │ │ +/// │ │ Batch Cache │ │ Stage N +/// │ │┌─────────┐ ┌─────────┐│ │ │ +/// │ ││ index 0 │ ... │index N-1││ │ │ +/// │ │└─────────┘ └─────────┘│ │ │ +/// │ └───────────────────────────┘ │ │ +/// └───────────────────────────┬─┬──────────┬───┬──────────────────────────┘ ■ +/// │0│ │N-1│ +/// └▲┘ ... └─▲─┘ +/// │ │ +/// ┌──┘ └──┐ +/// │ │ ■ +/// ┌┴┐ ... ┌──┴┐ │ +/// │0│ │N-1│ Stage N-1 +/// ┌┴─┴───────────────┴───┴┐ │ +/// │Arc │ │ +/// └───────────────────────┘ ■ +/// ``` +/// +/// Notice that the first consumer task, [NetworkBroadcastExec] task 1, triggers the execution of +/// the operator below the [BroadCastExec] and populates each cache index with the repective +/// partition. Subsequent consumer tasks, rather than executing the same partitions, read the +/// data from the cache for each partition. 
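As a reading aid (not part of the patch): a minimal, self-contained sketch of the partition arithmetic the doc comment above describes, assuming N input partitions and M consumer tasks. The helper names here are invented for illustration; in the patch the same arithmetic appears as `partition % input_partition_count` in `BroadcastExec::execute` and `task_index * input_partition_count + partition` in `NetworkBroadcastExec::execute`.

```rust
/// BroadcastExec exposes N*M virtual partitions and serves virtual partition `i`
/// from the cached result of real input partition `i % N`.
fn real_partition(virtual_partition: usize, input_partitions: usize) -> usize {
    virtual_partition % input_partitions
}

/// Consumer task `t` asks the stage below for virtual partition `t * N + p`,
/// so each consumer task fetches a disjoint range of virtual partition numbers.
fn stage_partition(task_index: usize, partition: usize, partitions_per_consumer: usize) -> usize {
    task_index * partitions_per_consumer + partition
}

fn main() {
    // Hypothetical sizes: 3 input partitions broadcast to 4 consumer tasks.
    let (n, m) = (3usize, 4usize);
    for task in 0..m {
        for p in 0..n {
            let virtual_p = stage_partition(task, p, n);
            // Each virtual partition is executed exactly once, yet every consumer
            // task ends up reading the same cached real partition `p`.
            assert_eq!(real_partition(virtual_p, n), p);
        }
    }
}
```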
+#[derive(Debug)] +pub struct BroadcastExec { + input: Arc, + consumer_task_count: usize, + properties: PlanProperties, + cached_batches: Vec>>>>, +} + +impl BroadcastExec { + pub fn new(input: Arc, consumer_task_count: usize) -> Self { + let input_partition_count = input.properties().partitioning.partition_count(); + let output_partition_count = input_partition_count * consumer_task_count; + + let properties = input + .properties() + .clone() + .with_partitioning(Partitioning::UnknownPartitioning(output_partition_count)); + + let cached_batches = (0..input_partition_count) + .map(|_| Arc::new(OnceCell::new())) + .collect(); + + Self { + input, + consumer_task_count, + properties, + cached_batches, + } + } + + pub fn input_partition_count(&self) -> usize { + self.input.properties().partitioning.partition_count() + } + + pub fn consumer_task_count(&self) -> usize { + self.consumer_task_count + } +} + +impl DisplayAs for BroadcastExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + let input_partition_count = self.input_partition_count(); + write!( + f, + "BroadcastExec: input_partitions={}, consumer_tasks={}, output_partitions={}", + input_partition_count, + self.consumer_task_count, + input_partition_count * self.consumer_task_count + ) + } +} + +impl ExecutionPlan for BroadcastExec { + fn name(&self) -> &str { + "BroadcastExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + require_one_child(children)?, + self.consumer_task_count, + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let real_partition = partition % self.input_partition_count(); + let cache = Arc::clone(&self.cached_batches[real_partition]); + let input = Arc::clone(&self.input); + let schema = self.schema(); + + // TODO: Stream batches as they're produced instead of collect-then-emit. Currently we + // wait for all batches before consumers receive any. Streaming would allow overlapping + // production with network transfer. + // + // Challenges: late subscribers must replay from buffer since tokio::sync::broadcast drops old messages, + // need proper error propagation to all consumers and backpressure handling. 
+ let stream = futures::stream::once(async move { + let batches = cache + .get_or_try_init(|| async { + let stream = input.execute(real_partition, context)?; + let batches: Vec = + futures::TryStreamExt::try_collect(stream).await?; + Ok::<_, datafusion::error::DataFusionError>(Arc::new(batches)) + }) + .await?; + let batches = Arc::clone(batches); + let batches_vec: Vec = batches.iter().cloned().collect(); + Ok::<_, datafusion::error::DataFusionError>(stream::iter( + batches_vec.into_iter().map(Ok), + )) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::TryStreamExt::try_flatten(stream), + ))) + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } +} diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 9dbb027c..cdf27ae1 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -1,14 +1,18 @@ +mod broadcast; mod children_isolator_union; mod common; mod distributed; mod metrics; +mod network_broadcast; mod network_coalesce; mod network_shuffle; mod partition_isolator; +pub use broadcast::BroadcastExec; pub use children_isolator_union::ChildrenIsolatorUnionExec; pub use distributed::DistributedExec; pub(crate) use metrics::MetricsWrapperExec; +pub use network_broadcast::NetworkBroadcastExec; pub use network_coalesce::NetworkCoalesceExec; pub use network_shuffle::NetworkShuffleExec; pub use partition_isolator::PartitionIsolatorExec; diff --git a/src/execution_plans/network_broadcast.rs b/src/execution_plans/network_broadcast.rs new file mode 100644 index 00000000..83eac032 --- /dev/null +++ b/src/execution_plans/network_broadcast.rs @@ -0,0 +1,339 @@ +use crate::DistributedConfig; +use crate::DistributedTaskContext; +use crate::common::require_one_child; +use crate::config_extension_ext::ContextGrpcMetadata; +use crate::distributed_planner::NetworkBoundary; +use crate::execution_plans::common::{manually_propagate_distributed_config, spawn_select_all}; +use crate::flight_service::DoGet; +use crate::metrics::MetricsCollectingStream; +use crate::metrics::proto::MetricsSetProto; +use crate::networking::get_distributed_channel_resolver; +use crate::protobuf::StageKey; +use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; +use crate::stage::{MaybeEncodedPlan, Stage}; +use arrow_flight::Ticket; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::error::FlightError; +use bytes::Bytes; +use dashmap::DashMap; +use datafusion::common::internal_datafusion_err; +use datafusion::error::DataFusionError; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, +}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use http::Extensions; +use prost::Message; +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; +use tonic::Request; +use tonic::metadata::MetadataMap; +use uuid::Uuid; + +/// Network boundary for broadcasting data to all consumer tasks. +/// +/// This operator works with [BroadcastExec] which scales up partitions so each +/// consumer task fetches a unique set of partition numbers. Each partition request +/// is sent to all stage tasks because PartitionIsolatorExec maps the same logical +/// partition to different actual data on each task. 
+/// +/// Here are some examples of how [NetworkBroadcastExec] distributes data: +/// +/// # 1 to many +/// +/// ```text +/// ┌────────────────────────┐ ┌────────────────────────┐ ■ +/// │ NetworkBroadcastExec │ │ NetworkBroadcastExec │ │ +/// │ (task 1) │ ... │ (task M) │ │ +/// │ │ │ │ Stage N +/// │ Populates Caches │ │ Populates Caches │ │ +/// └────────┬─┬┬─┬┬─┬───────┘ └────────┬─┬┬─┬┬─┬───────┘ │ +/// │0││1││2│ │0││1││2│ │ +/// └▲┘└▲┘└▲┘ └▲┘└▲┘└▲┘ ■ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ │ └─────────────┐ ┌──────────────────┘ │ │ +/// │ └─────────────┐ │ │ ┌───────────────┘ │ +/// └─────────────┐ │ │ │ │ ┌─────────────┘ +/// │ │ │ │ │ │ +/// ┌┴┐┌┴┐┌┴┐ ... ┌───┴┐┌───┴┐┌──┴─┐ +/// │1││2││3│ │NM-3││NM-2││NM-1│ ■ +/// ┌┴─┴┴─┴┴─┴─────┴────┴┴────┴┴────┴─┐ │ +/// │ BroadcastExec │ │ +/// │ ┌───────────────┐ │ Stage N-1 +/// │ │ Batch Cache │ │ │ +/// │ │ ┌─┐ ┌─┐ ┌─┐ │ │ │ +/// │ │ │0│ │1│ │2│ │ │ │ +/// │ │ └─┘ └─┘ └─┘ │ │ │ +/// │ └───────────────┘ │ │ +/// └───────────┬─┬─┬─┬─┬─┬───────────┘ │ +/// │0│ │1│ │2│ │ +/// └▲┘ └▲┘ └▲┘ ■ +/// │ │ │ +/// │ │ │ +/// │ │ │ +/// ┌┴┐ ┌┴┐ ┌┴┐ ■ +/// │0│ │1│ │2│ │ +/// ┌──────┴─┴─┴─┴─┴─┴──────┐ Stage N-2 +/// │Arc │ │ +/// │ (task 1) │ │ +/// └───────────────────────┘ ■ +/// ``` +/// +/// # Many to many +/// +/// ```text +/// ┌────────────────────────┐ ┌────────────────────────┐ ■ +/// │ NetworkBroadcastExec │ │ NetworkBroadcastExec │ │ +/// │ (task 1) │ │ (task M) │ │ +/// │ │ ... │ │ Stage N +/// │ Populates Caches │ │ Cache Hits │ │ +/// └────────┬─┬┬─┬┬─┬───────┘ └────────┬─┬┬─┬┬─┬───────┘ │ +/// │0││1││2│ │0││1││2│ │ +/// └▲┘└▲┘└▲┘ └▲┘└▲┘└▲┘ ■ +/// │ │ │ │ │ │ +/// ┌──────────┴──┼──┼────────────────────────────────┐ │ │ │ +/// │ ┌──────────┴──┼────────────────────────────────┼──┐ │ │ │ +/// │ │ ┌──────────┴────────────────────────────────┼──┼──┐ │ │ │ +/// │ │ │ │ │ │ │ │ │ +/// │ │ │ ┌─────────────────────────────────┼──┼──┼────┴──┼─┐│ +/// │ │ │ │ ┌───────────────────────────┼──┼──┼───────┴─┼┼─────┐ +/// │ │ │ │ │ ┌─────────────────────┼──┼──┼─────────┼┴─────┼────┐ +/// │ │ │ │ │ │ │ │ │ │ │ │ +/// ┌┴┐┌┴┐┌┴┐ ... ┌──┴─┐┌──┴─┐┌──┴─┐ ┌┴┐┌┴┐┌┴┐ ... ┌──┴─┐┌───┴┐┌──┴─┐ ■ +/// │0││1││2│ │3M-3││3M-2││3M-1│ │0││1││2│ │3M-3││3M-2││3M-1│ │ +/// ┌┴─┴┴─┴┴─┴─────┴────┴┴────┴┴────┴┐ ┌┴─┴┴─┴┴─┴─────┴────┴┴────┴┴────┴┐ │ +/// │ BroadcastExec │ │ BroadcastExec │ │ +/// │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ +/// │ │ Batch Cache │ │ │ │ Batch Cache │ │ │ +/// │ │ ┌─┐ ┌─┐ ┌─┐ │ │ ... │ │ ┌─┐ ┌─┐ ┌─┐ │ │ Stage N-1 +/// │ │ │0│ │1│ │2│ │ │ │ │ │0│ │1│ │2│ │ │ │ +/// │ │ └─┘ └─┘ └─┘ │ │ │ │ └─┘ └─┘ └─┘ │ │ │ +/// │ └───────────────┘ │ │ └───────────────┘ │ │ +/// └───────────┬─┬─┬─┬─┬─┬──────────┘ └───────────┬─┬─┬─┬─┬─┬──────────┘ │ +/// │0│ │1│ │2│ │0│ │1│ │2│ │ +/// └▲┘ └▲┘ └▲┘ └▲┘ └▲┘ └▲┘ ■ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// ┌┴┐ ┌┴┐ ┌┴┐ ┌┴┐ ┌┴┐ ┌┴┐ ■ +/// │0│ │1│ │2│ │0│ │1│ │2│ │ +/// ┌──────┴─┴─┴─┴─┴─┴──────┐ ┌──────┴─┴─┴─┴─┴─┴──────┐ Stage N-2 +/// │Arc │ ... │Arc │ │ +/// │ (task 1) │ │ (task N) │ │ +/// └───────────────────────┘ └───────────────────────┘ ■ +/// ``` +/// +/// Notice in this diagram that each [NetworkBroadcastExec] sends a request to fetch data from each +/// [BroadcastExec] in the stage below per partition. This is because each [BroadcastExec] has its +/// own cache which contains partial results for the partition. It is the [NetworkBroadcastExec]'s +/// job to merge these partial partitions to then broadcast complete data to the consumers. 
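As a reading aid (not part of the patch): the per-partition fan-in sketched in the diagrams above — one request per producer task, all for the same logical partition, merged into a single stream — can be pictured with plain `futures` streams. This is a toy sketch (integers stand in for record batches, assuming the `tokio` and `futures` crates); the operator itself does the equivalent over Arrow Flight streams via `spawn_select_all`, so arrival order across producers is not guaranteed.

```rust
use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    // Three producer tasks each hold a partial result for the same logical partition
    // (their own BroadcastExec cache entry).
    let partials = vec![vec![1, 2], vec![3], vec![4, 5, 6]];

    // One stream per producer task, fanned in by the consumer. Order across producers
    // is unspecified, which is fine: the build side of a CollectLeft join only needs
    // the complete set of rows, not a particular order.
    let merged: Vec<i32> = stream::select_all(partials.into_iter().map(stream::iter))
        .collect()
        .await;

    assert_eq!(merged.len(), 6);
    assert_eq!(merged.iter().sum::<i32>(), 21);
}
```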
+#[derive(Debug, Clone)]
+pub struct NetworkBroadcastExec {
+    pub(crate) properties: PlanProperties,
+    pub(crate) input_stage: Stage,
+    pub(crate) metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
+}
+
+impl NetworkBroadcastExec {
+    /// Creates a [NetworkBroadcastExec].
+    ///
+    /// Use [Self::with_inner_broadcast] to create the full broadcast stack with
+    /// BroadcastExec for caching.
+    pub fn try_new(
+        input: Arc<dyn ExecutionPlan>,
+        query_id: Uuid,
+        stage_num: usize,
+        input_task_count: usize,
+    ) -> Result<Self, DataFusionError> {
+        let total_partitions = input.properties().partitioning.partition_count();
+        let input_partition_count =
+            if let Some(broadcast) = input.as_any().downcast_ref::<super::BroadcastExec>() {
+                broadcast.input_partition_count()
+            } else {
+                total_partitions
+            };
+        let input_stage = Stage::new(query_id, stage_num, input.clone(), input_task_count);
+        let properties = input
+            .properties()
+            .clone()
+            .with_partitioning(Partitioning::UnknownPartitioning(input_partition_count));
+
+        Ok(Self {
+            properties,
+            input_stage,
+            metrics_collection: Default::default(),
+        })
+    }
+
+    /// Builds the inner broadcast stack: BroadcastExec -> NetworkBroadcastExec.
+    ///
+    /// This is the standard way to create a broadcast network boundary. The stack is:
+    /// 1. [BroadcastExec]: Caches batches and scales partitions
+    /// 2. [NetworkBroadcastExec]: Fetches cached data from producer tasks
+    pub fn with_inner_broadcast(
+        input: Arc<dyn ExecutionPlan>,
+        query_id: Uuid,
+        stage_num: usize,
+        consumer_task_count: usize,
+        input_task_count: usize,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        let broadcast_exec = Arc::new(super::BroadcastExec::new(input, consumer_task_count));
+        Ok(Arc::new(Self::try_new(
+            broadcast_exec,
+            query_id,
+            stage_num,
+            input_task_count,
+        )?))
+    }
+}
+
+impl NetworkBoundary for NetworkBroadcastExec {
+    fn with_input_stage(
+        &self,
+        input_stage: Stage,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        let mut self_clone = self.clone();
+        self_clone.input_stage = input_stage;
+        Ok(Arc::new(self_clone))
+    }
+
+    fn input_stage(&self) -> &Stage {
+        &self.input_stage
+    }
+}
+
+impl DisplayAs for NetworkBroadcastExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        let input_tasks = self.input_stage.tasks.len();
+        let stage = self.input_stage.num;
+        let consumer_partitions = self.properties.partitioning.partition_count();
+        let stage_partitions = self
+            .input_stage
+            .plan
+            .decoded()
+            .map(|p| p.properties().partitioning.partition_count())
+            .unwrap_or(0);
+        write!(
+            f,
+            "[Stage {stage}] => NetworkBroadcastExec: partitions_per_consumer={consumer_partitions}, stage_partitions={stage_partitions}, input_tasks={input_tasks}",
+        )
+    }
+}
+
+impl ExecutionPlan for NetworkBroadcastExec {
+    fn name(&self) -> &str {
+        "NetworkBroadcastExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        match &self.input_stage.plan {
+            MaybeEncodedPlan::Decoded(plan) => vec![plan],
+            MaybeEncodedPlan::Encoded(_) => vec![],
+        }
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        let new_child = require_one_child(children)?;
+        let mut new_stage = self.input_stage.clone();
+        new_stage.plan = MaybeEncodedPlan::Decoded(new_child);
+
+        Ok(Arc::new(Self {
+            properties: self.properties.clone(),
+            input_stage: new_stage,
+            metrics_collection: self.metrics_collection.clone(),
+        }))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream, DataFusionError> {
+        let task_context = DistributedTaskContext::from_ctx(&context);
+        let task_index = task_context.task_index;
+
+        let channel_resolver = get_distributed_channel_resolver(context.as_ref());
+        let d_cfg = DistributedConfig::from_config_options(context.session_config().options())?;
+        let retrieve_metrics = d_cfg.collect_metrics;
+
+        let input_stage = &self.input_stage;
+        let encoded_input_plan = input_stage.plan.encoded()?;
+        let input_stage_tasks = input_stage.tasks.to_vec();
+        let input_task_count = input_stage_tasks.len();
+        let input_stage_num = input_stage.num as u64;
+        let query_id = Bytes::from(input_stage.query_id.as_bytes().to_vec());
+        let input_partition_count = self.properties.partitioning.partition_count();
+        // Each consumer task owns `input_partition_count` partitions, so its local `partition`
+        // maps to this global partition index in the producer stage.
+        let stage_partition = task_index * input_partition_count + partition;
+
+        let context_headers = ContextGrpcMetadata::headers_from_ctx(&context);
+        let context_headers = manually_propagate_distributed_config(context_headers, d_cfg);
+        let metrics_collection = self.metrics_collection.clone();
+
+        let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
+            let channel_resolver = Arc::clone(&channel_resolver);
+            let metrics_collection_capture = metrics_collection.clone();
+
+            let ticket = Request::from_parts(
+                MetadataMap::from_headers(context_headers.clone()),
+                Extensions::default(),
+                Ticket {
+                    ticket: DoGet {
+                        plan_proto: encoded_input_plan.clone(),
+                        target_partition: stage_partition as u64,
+                        stage_key: Some(StageKey::new(query_id.clone(), input_stage_num, i as u64)),
+                        target_task_index: i as u64,
+                        target_task_count: input_task_count as u64,
+                    }
+                    .encode_to_vec()
+                    .into(),
+                },
+            );
+
+            async move {
+                let url = task.url.ok_or(internal_datafusion_err!(
+                    "NetworkBroadcastExec: task is unassigned"
+                ))?;
+
+                let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
+                let stream = client
+                    .do_get(ticket)
+                    .await
+                    .map_err(map_status_to_datafusion_error)?
+                    .into_inner()
+                    .map_err(|err| FlightError::Tonic(Box::new(err)));
+
+                let stream = if retrieve_metrics {
+                    MetricsCollectingStream::new(stream, metrics_collection_capture).left_stream()
+                } else {
+                    stream.right_stream()
+                };
+
+                Ok(FlightRecordBatchStream::new_from_flight_data(stream)
+                    .map_err(map_flight_to_datafusion_error))
+            }
+            .try_flatten_stream()
+            .boxed()
+        });
+
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            self.schema(),
+            spawn_select_all(stream.collect(), Arc::clone(context.memory_pool())),
+        )))
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 50e1ef65..ea40e603 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,7 +20,8 @@ pub use distributed_planner::{
     TaskCountAnnotation, TaskEstimation, TaskEstimator,
 };
 pub use execution_plans::{
-    DistributedExec, NetworkCoalesceExec, NetworkShuffleExec, PartitionIsolatorExec,
+    BroadcastExec, DistributedExec, NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec,
+    PartitionIsolatorExec,
 };
 pub use flight_service::{
     DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, Worker,
diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs
index f2870e77..58cf00bb 100644
--- a/src/metrics/task_metrics_collector.rs
+++ b/src/metrics/task_metrics_collector.rs
@@ -1,3 +1,4 @@
+use crate::NetworkBroadcastExec;
 use crate::execution_plans::NetworkCoalesceExec;
 use crate::execution_plans::NetworkShuffleExec;
 use crate::metrics::proto::MetricsSetProto;
@@ -44,6 +45,8 @@ impl TreeNodeRewriter for TaskMetricsCollector {
             Some(Arc::clone(&node.metrics_collection))
         } else if let Some(node) = plan.as_any().downcast_ref::<NetworkCoalesceExec>() {
             Some(Arc::clone(&node.metrics_collection))
+        } else if let Some(node) = plan.as_any().downcast_ref::<NetworkBroadcastExec>() {
+            Some(Arc::clone(&node.metrics_collection))
         } else {
             None
         };
diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs
index 7d253313..1307c277 100644
--- a/src/protobuf/distributed_codec.rs
+++ b/src/protobuf/distributed_codec.rs
@@ -1,5 +1,7 @@
 use super::get_distributed_user_codecs;
-use crate::execution_plans::{ChildrenIsolatorUnionExec, NetworkCoalesceExec};
+use crate::execution_plans::{
+    BroadcastExec, ChildrenIsolatorUnionExec, NetworkBroadcastExec, NetworkCoalesceExec,
+};
 use crate::stage::{ExecutionTask, MaybeEncodedPlan, Stage};
 use crate::{DistributedTaskContext, NetworkBoundary};
 use crate::{NetworkShuffleExec, PartitionIsolatorExec};
@@ -153,6 +155,46 @@ impl PhysicalExtensionCodec for DistributedCodec {
                     n_tasks as usize,
                 )))
             }
+            DistributedExecNode::NetworkBroadcast(NetworkBroadcastExecProto {
+                schema,
+                partitioning,
+                input_stage,
+            }) => {
+                let schema: Schema = schema
+                    .as_ref()
+                    .map(|s| s.try_into())
+                    .ok_or(proto_error("NetworkBroadcastExec is missing schema"))??;
+
+                let partitioning = parse_protobuf_partitioning(
+                    partitioning.as_ref(),
+                    ctx,
+                    &schema,
+                    &DistributedCodec {},
+                )?
+                .ok_or(proto_error("NetworkBroadcastExec is missing partitioning"))?;
+
+                Ok(Arc::new(new_network_broadcast_exec(
+                    partitioning,
+                    Arc::new(schema),
+                    parse_stage_proto(input_stage, inputs)?,
+                )))
+            }
+            DistributedExecNode::Broadcast(BroadcastExecProto {
+                consumer_task_count,
+            }) => {
+                if inputs.len() != 1 {
+                    return Err(proto_error(format!(
+                        "BroadcastExec expects exactly one child, got {}",
+                        inputs.len()
+                    )));
+                }
+
+                let child = inputs.first().unwrap();
+                Ok(Arc::new(BroadcastExec::new(
+                    child.clone(),
+                    consumer_task_count as usize,
+                )))
+            }
             DistributedExecNode::ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto {
                 partition_count,
                 task_idx_map,
@@ -248,6 +290,31 @@ impl PhysicalExtensionCodec for DistributedCodec {
                 node: Some(DistributedExecNode::PartitionIsolator(inner)),
             };
 
+            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
+        } else if let Some(node) = node.as_any().downcast_ref::<NetworkBroadcastExec>() {
+            let inner = NetworkBroadcastExecProto {
+                schema: Some(node.schema().try_into()?),
+                partitioning: Some(serialize_partitioning(
+                    node.properties().output_partitioning(),
+                    &DistributedCodec {},
+                )?),
+                input_stage: Some(encode_stage_proto(node.input_stage())?),
+            };
+
+            let wrapper = DistributedExecProto {
+                node: Some(DistributedExecNode::NetworkBroadcast(inner)),
+            };
+
+            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
+        } else if let Some(node) = node.as_any().downcast_ref::<BroadcastExec>() {
+            let inner = BroadcastExecProto {
+                consumer_task_count: node.consumer_task_count() as u64,
+            };
+
+            let wrapper = DistributedExecProto {
+                node: Some(DistributedExecNode::Broadcast(inner)),
+            };
+
             wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
         } else if let Some(node) = node.as_any().downcast_ref::<ChildrenIsolatorUnionExec>() {
             let inner = ChildrenIsolatorUnionExecProto {
@@ -331,7 +398,7 @@ pub struct ExecutionTaskProto {
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct DistributedExecProto {
-    #[prost(oneof = "DistributedExecNode", tags = "1, 2, 3, 4, 5")]
+    #[prost(oneof = "DistributedExecNode", tags = "1, 2, 3, 4, 5, 6")]
     pub node: Option<DistributedExecNode>,
 }
@@ -345,6 +412,10 @@ pub enum DistributedExecNode {
     PartitionIsolator(PartitionIsolatorExecProto),
     #[prost(message, tag = "4")]
     ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto),
+    #[prost(message, tag = "5")]
+    NetworkBroadcast(NetworkBroadcastExecProto),
+    #[prost(message, tag = "6")]
+    Broadcast(BroadcastExecProto),
 }
 
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -437,6 +508,39 @@ fn new_network_coalesce_tasks_exec(
     }
 }
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct NetworkBroadcastExecProto {
+    #[prost(message, optional, tag = "1")]
+    schema: Option<datafusion_proto::protobuf::Schema>,
+    #[prost(message, optional, tag = "2")]
+    partitioning: Option<datafusion_proto::protobuf::Partitioning>,
+    #[prost(message, optional, tag = "3")]
+    input_stage: Option<StageProto>,
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct BroadcastExecProto {
+    #[prost(uint64, tag = "1")]
+    pub consumer_task_count: u64,
+}
+
+fn new_network_broadcast_exec(
+    partitioning: Partitioning,
+    schema: SchemaRef,
+    input_stage: Stage,
+) -> NetworkBroadcastExec {
+    NetworkBroadcastExec {
+        properties: PlanProperties::new(
+            EquivalenceProperties::new(schema),
+            partitioning,
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        ),
+        input_stage,
+        metrics_collection: Default::default(),
+    }
+}
+
 fn encode_tasks(tasks: &[ExecutionTask]) -> Vec<ExecutionTaskProto> {
     tasks
         .iter()
diff --git a/tests/tpcds_correctness_test.rs b/tests/tpcds_correctness_test.rs
index 51fda085..41f70f40 100644
--- a/tests/tpcds_correctness_test.rs
+++ b/tests/tpcds_correctness_test.rs
@@ -565,6 +565,13 @@ mod tests {
             .with_distributed_files_per_task(FILES_PER_TASK)?
            .with_distributed_cardinality_effect_task_scale_factor(CARDINALITY_TASK_COUNT_FACTOR)?;
 
+        // Enable broadcast joins if the BROADCAST_JOINS_ENABLED env var is set.
+        let d_ctx = if std::env::var("BROADCAST_JOINS_ENABLED").is_ok() {
+            d_ctx.with_distributed_broadcast_joins_enabled(true)?
+        } else {
+            d_ctx
+        };
+
         tpcds::register_tables(&s_ctx, &data_dir).await?;
         tpcds::register_tables(&d_ctx, &data_dir).await?;
diff --git a/tests/tpch_correctness_test.rs b/tests/tpch_correctness_test.rs
index 012dea32..d3e6089f 100644
--- a/tests/tpch_correctness_test.rs
+++ b/tests/tpch_correctness_test.rs
@@ -2,9 +2,9 @@
 mod tests {
     use datafusion::physical_plan::execute_stream;
     use datafusion::prelude::SessionContext;
-    use datafusion_distributed::DefaultSessionBuilder;
     use datafusion_distributed::test_utils::localhost::start_localhost_context;
     use datafusion_distributed::test_utils::tpch;
+    use datafusion_distributed::{DefaultSessionBuilder, DistributedExt};
     use futures::TryStreamExt;
     use std::error::Error;
     use std::fmt::Display;
@@ -137,6 +137,13 @@ mod tests {
     async fn test_tpch_query(sql: String) -> Result<(), Box<dyn Error>> {
         let (ctx, _guard) = start_localhost_context(4, DefaultSessionBuilder).await;
 
+        // Enable broadcast joins if the BROADCAST_JOINS_ENABLED env var is set.
+        let ctx = if std::env::var("BROADCAST_JOINS_ENABLED").is_ok() {
+            ctx.with_distributed_broadcast_joins_enabled(true)?
+        } else {
+            ctx
+        };
+
         let results_d = run_tpch_query(ctx, sql.clone()).await?;
         let results_s = run_tpch_query(SessionContext::new(), sql).await?;