diff --git a/src/distributed_planner/distributed_physical_optimizer_rule.rs b/src/distributed_planner/distributed_physical_optimizer_rule.rs
index 970e9f0..eb0f6bd 100644
--- a/src/distributed_planner/distributed_physical_optimizer_rule.rs
+++ b/src/distributed_planner/distributed_physical_optimizer_rule.rs
@@ -1,10 +1,8 @@
 use crate::distributed_planner::distributed_config::DistributedConfig;
 use crate::distributed_planner::distributed_plan_error::get_distribute_plan_err;
 use crate::distributed_planner::task_estimator::TaskEstimator;
-use crate::distributed_planner::{
-    DistributedPlanError, NetworkBoundaryExt, limit_tasks_err, non_distributable_err,
-};
-use crate::execution_plans::{DistributedExec, NetworkCoalesceExec};
+use crate::distributed_planner::{DistributedPlanError, NetworkBoundaryExt, non_distributable_err};
+use crate::execution_plans::{DistributedExec, NetworkBroadcastExec, NetworkCoalesceExec};
 use crate::stage::Stage;
 use crate::{ChannelResolver, NetworkShuffleExec, PartitionIsolatorExec};
 use datafusion::common::plan_err;
@@ -48,7 +46,7 @@ use uuid::Uuid;
 ///
 ///
 /// 2. Break down the plan into stages
-/// 
+///
 /// Based on the network boundaries ([NetworkShuffleExec], [NetworkCoalesceExec], ...) placed in
 /// the plan by the first step, the plan is divided into stages and tasks are assigned to each
 /// stage.
@@ -408,14 +406,29 @@ fn _distribute_plan_inner(
     n_tasks: usize,
 ) -> Result {
     let mut distributed = plan.clone().transform_down(|plan| {
-        // We cannot break down CollectLeft hash joins into more than 1 task, as these need
-        // a full materialized build size with all the data in it.
-        //
-        // Maybe in the future these can be broadcast joins?
-        if let Some(node) = plan.as_any().downcast_ref::<HashJoinExec>() {
-            if n_tasks > 1 && node.mode == PartitionMode::CollectLeft {
-                return Err(limit_tasks_err(1));
-            }
+        // Handle CollectLeft hash joins by injecting a NetworkBroadcastExec on the build (left) side.
+        // The small left table is broadcast to every worker, so the join itself can run in parallel.
+        if let Some(node) = plan.as_any().downcast_ref::<HashJoinExec>()
+            && n_tasks > 1 && node.mode == PartitionMode::CollectLeft
+        {
+            let broadcast_left = Arc::new(NetworkBroadcastExec::new(
+                node.left().clone(),
+                1,
+            ));
+
+            // Reconstruct the HashJoinExec with the broadcast build side.
+            let new_join = HashJoinExec::try_new(
+                broadcast_left,
+                node.right().clone(),
+                node.on().to_vec(),
+                node.filter().cloned(),
+                node.join_type(),
+                node.projection.clone(),
+                *node.partition_mode(),
+                node.null_equality(),
+            )?;
+
+            return Ok(Transformed::yes(Arc::new(new_join)));
         }

         // We cannot distribute [StreamingTableExec] nodes, so abort distribution.
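To make the intent of the new branch easier to scan, here is the same rewrite pulled out of the `transform_down` closure as a standalone sketch. The free function and its `use` paths are illustrative only (they are not part of this patch); the logic mirrors the hunk above: the build side is wrapped in a `NetworkBroadcastExec` with a single input task, and the join is rebuilt unchanged around it.

```rust
use std::sync::Arc;

use datafusion::common::tree_node::Transformed;
use datafusion::error::DataFusionError;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};

use crate::execution_plans::NetworkBroadcastExec;

/// Hypothetical helper mirroring the rewrite inside `_distribute_plan_inner`.
fn broadcast_collect_left(
    plan: Arc<dyn ExecutionPlan>,
    n_tasks: usize,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>, DataFusionError> {
    if let Some(node) = plan.as_any().downcast_ref::<HashJoinExec>()
        && n_tasks > 1
        && node.mode == PartitionMode::CollectLeft
    {
        // The build side is collected by exactly one input task (hence `1`) and then
        // replicated over the network to every task that runs the probe side.
        let broadcast_left = Arc::new(NetworkBroadcastExec::new(node.left().clone(), 1));
        let new_join = HashJoinExec::try_new(
            broadcast_left,
            node.right().clone(),
            node.on().to_vec(),
            node.filter().cloned(),
            node.join_type(),
            node.projection.clone(),
            *node.partition_mode(),
            node.null_equality(),
        )?;
        return Ok(Transformed::yes(Arc::new(new_join)));
    }
    Ok(Transformed::no(plan))
}
```

Because the placeholder reports a single input task, the later stage-breakdown step turns the collected build side into its own one-task stage, while the probe side keeps its task count, as the updated snapshots below show.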
@@ -733,11 +746,21 @@ mod tests { }) .await; assert_snapshot!(plan, @r" - CoalesceBatchesExec: target_batch_size=8192 - HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2] - CoalescePartitionsExec - DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet - DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet + ┌───── DistributedExec ── Tasks: t0:[p0] + │ CoalescePartitionsExec + │ [Stage 2] => NetworkCoalesceExec: output_partitions=3, input_tasks=3 + └────────────────────────────────────────────────── + ┌───── Stage 2 ── Tasks: t0:[p0] t1:[p1] t2:[p2] + │ CoalesceBatchesExec: target_batch_size=8192 + │ HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2] + │ NetworkBroadcastExec: [Stage 1] (1 tasks) + │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet + └────────────────────────────────────────────────── + ┌───── Stage 1 ── Tasks: t0:[p0] + │ CoalescePartitionsExec + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + └────────────────────────────────────────────────── "); } @@ -773,39 +796,45 @@ mod tests { assert_snapshot!(plan, @r" ┌───── DistributedExec ── Tasks: t0:[p0] │ CoalescePartitionsExec - │ CoalesceBatchesExec: target_batch_size=8192 - │ HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainTomorrow@1, RainTomorrow@1)], projection=[MinTemp@0, MaxTemp@2] - │ CoalescePartitionsExec - │ [Stage 2] => NetworkCoalesceExec: output_partitions=8, input_tasks=2 - │ ProjectionExec: expr=[avg(weather.MaxTemp)@1 as MaxTemp, RainTomorrow@0 as RainTomorrow] - │ AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MaxTemp)] - │ [Stage 3] => NetworkShuffleExec: output_partitions=4, input_tasks=3 + │ [Stage 5] => NetworkCoalesceExec: output_partitions=8, input_tasks=2 └────────────────────────────────────────────────── - ┌───── Stage 2 ── Tasks: t0:[p0..p3] t1:[p0..p3] - │ ProjectionExec: expr=[avg(weather.MinTemp)@1 as MinTemp, RainTomorrow@0 as RainTomorrow] - │ AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MinTemp)] - │ [Stage 1] => NetworkShuffleExec: output_partitions=4, input_tasks=3 + ┌───── Stage 5 ── Tasks: t0:[p0..p3] t1:[p4..p7] + │ CoalesceBatchesExec: target_batch_size=8192 + │ HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainTomorrow@1, RainTomorrow@1)], projection=[MinTemp@0, MaxTemp@2] + │ NetworkBroadcastExec: [Stage 3] (1 tasks) + │ ProjectionExec: expr=[avg(weather.MaxTemp)@1 as MaxTemp, RainTomorrow@0 as RainTomorrow] + │ AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MaxTemp)] + │ [Stage 4] => NetworkShuffleExec: output_partitions=4, input_tasks=3 
└────────────────────────────────────────────────── - ┌───── Stage 1 ── Tasks: t0:[p0..p7] t1:[p0..p7] t2:[p0..p7] + ┌───── Stage 3 ── Tasks: t0:[p0] + │ CoalescePartitionsExec + │ [Stage 2] => NetworkCoalesceExec: output_partitions=8, input_tasks=2 + └────────────────────────────────────────────────── + ┌───── Stage 2 ── Tasks: t0:[p0..p3] t1:[p0..p3] + │ ProjectionExec: expr=[avg(weather.MinTemp)@1 as MinTemp, RainTomorrow@0 as RainTomorrow] + │ AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MinTemp)] + │ [Stage 1] => NetworkShuffleExec: output_partitions=4, input_tasks=3 + └────────────────────────────────────────────────── + ┌───── Stage 1 ── Tasks: t0:[p0..p7] t1:[p0..p7] t2:[p0..p7] + │ CoalesceBatchesExec: target_batch_size=8192 + │ RepartitionExec: partitioning=Hash([RainTomorrow@0], 8), input_partitions=4 + │ AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MinTemp)] + │ CoalesceBatchesExec: target_batch_size=8192 + │ FilterExec: RainToday@1 = yes, projection=[MinTemp@0, RainTomorrow@2] + │ RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 + │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = yes, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= yes AND yes <= RainToday_max@1, required_guarantees=[RainToday in (yes)] + └────────────────────────────────────────────────── + ┌───── Stage 4 ── Tasks: t0:[p0..p7] t1:[p0..p7] t2:[p0..p7] │ CoalesceBatchesExec: target_batch_size=8192 │ RepartitionExec: partitioning=Hash([RainTomorrow@0], 8), input_partitions=4 - │ AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MinTemp)] + │ AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MaxTemp)] │ CoalesceBatchesExec: target_batch_size=8192 - │ FilterExec: RainToday@1 = yes, projection=[MinTemp@0, RainTomorrow@2] + │ FilterExec: RainToday@1 = no, projection=[MaxTemp@0, RainTomorrow@2] │ RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] - │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = yes, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= yes AND yes <= RainToday_max@1, required_guarantees=[RainToday in (yes)] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = no, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= no AND no <= RainToday_max@1, required_guarantees=[RainToday in (no)] └────────────────────────────────────────────────── - ┌───── Stage 3 ── Tasks: t0:[p0..p3] t1:[p0..p3] t2:[p0..p3] - │ CoalesceBatchesExec: target_batch_size=8192 - │ RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4 - │ AggregateExec: mode=Partial, 
gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MaxTemp)] - │ CoalesceBatchesExec: target_batch_size=8192 - │ FilterExec: RainToday@1 = no, projection=[MaxTemp@0, RainTomorrow@2] - │ RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 - │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] - │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = no, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= no AND no <= RainToday_max@1, required_guarantees=[RainToday in (no)] - └────────────────────────────────────────────────── "); } diff --git a/src/distributed_planner/network_boundary.rs b/src/distributed_planner/network_boundary.rs index 8053a53..17febc5 100644 --- a/src/distributed_planner/network_boundary.rs +++ b/src/distributed_planner/network_boundary.rs @@ -1,3 +1,4 @@ +use crate::execution_plans::NetworkBroadcastExec; use crate::{NetworkCoalesceExec, NetworkShuffleExec, Stage}; use datafusion::common::plan_err; use datafusion::physical_plan::ExecutionPlan; @@ -82,6 +83,8 @@ impl NetworkBoundaryExt for dyn ExecutionPlan { Some(node) } else if let Some(node) = self.as_any().downcast_ref::() { Some(node) + } else if let Some(node) = self.as_any().downcast_ref::() { + Some(node) } else { None } diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 377c273..e96d7e8 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -1,12 +1,14 @@ mod common; mod distributed; mod metrics; +mod network_broadcast; mod network_coalesce; mod network_shuffle; mod partition_isolator; pub use distributed::DistributedExec; pub(crate) use metrics::MetricsWrapperExec; +pub use network_broadcast::{NetworkBroadcastExec, NetworkBroadcastReady}; pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady}; pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec}; pub use partition_isolator::PartitionIsolatorExec; diff --git a/src/execution_plans/network_broadcast.rs b/src/execution_plans/network_broadcast.rs new file mode 100644 index 0000000..c7d7c5d --- /dev/null +++ b/src/execution_plans/network_broadcast.rs @@ -0,0 +1,332 @@ +use crate::ChannelResolver; +use crate::channel_resolver_ext::get_distributed_channel_resolver; +use crate::config_extension_ext::ContextGrpcMetadata; +use crate::distributed_planner::{InputStageInfo, NetworkBoundary}; +use crate::execution_plans::common::require_one_child; +use crate::flight_service::DoGet; +use crate::metrics::MetricsCollectingStream; +use crate::metrics::proto::MetricsSetProto; +use crate::protobuf::{StageKey, map_flight_to_datafusion_error, map_status_to_datafusion_error}; +use crate::stage::{MaybeEncodedPlan, Stage}; +use arrow_flight::Ticket; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::error::FlightError; +use bytes::Bytes; +use dashmap::DashMap; +use datafusion::common::{exec_err, internal_datafusion_err, plan_err}; +use datafusion::error::DataFusionError; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use http::Extensions; +use prost::Message; +use std::any::Any; +use 
std::sync::Arc; +use tonic::Request; +use tonic::metadata::MetadataMap; + +/// [ExecutionPlan] that broadcasts data from a single task to multiple tasks across the network. +/// +/// This operator is used when a small dataset needs to be replicated to all workers in the next +/// stage. The most common use case is hash joins with `CollectLeft` partition mode, where the +/// small build side (left table) is collected into a single partition and then broadcast to all +/// workers processing the large probe side (right table). +/// +/// Unlike [NetworkShuffleExec] which redistributes data across tasks, [NetworkBroadcastExec] +/// replicates the entire input to each task in the next stage. This allows parallel execution +/// of operations that would otherwise be forced to run single-threaded. +/// +/// 1 to many (broadcast) +/// +/// ┌───────────────────────────┐ ┌───────────────────────────┐ ┌───────────────────────────┐ ■ +/// │ NetworkBroadcastExec │ │ NetworkBroadcastExec │ │ NetworkBroadcastExec │ │ +/// │ (task 1) │ │ (task 2) │ │ (task 3) │ │ +/// │ (full copy) │ │ (full copy) │ │ (full copy) │ Stage N+1 +/// └───────────────────────────┘ └───────────────────────────┘ └───────────────────────────┘ │ +/// ▲ ▲ ▲ │ +/// │ │ │ ■ +/// └──────────────────────────────┴──────────────────────────────┘ +/// │ ■ +/// ┌───────────────────────────┐ │ +/// │ CoalesceExec or │ │ +/// │ HashJoinExec build │ Stage N +/// │ (task 1) │ │ +/// └───────────────────────────┘ │ +/// ■ +/// +/// Broadcast join example (CollectLeft hash join) +/// +/// Stage N+1: Hash Join (3 tasks running in parallel) +/// ┌──────────────────────┐ ┌──────────────────────┐ ┌──────────────────────┐ +/// │ HashJoinExec t1 │ │ HashJoinExec t2 │ │ HashJoinExec t3 │ +/// │ left: small (bcast) │ │ left: small (bcast) │ │ left: small (bcast) │ +/// │ right: large (p1) │ │ right: large (p2) │ │ right: large (p3) │ +/// └──┬─────────────────┬─┘ └──┬─────────────────┬─┘ └──┬─────────────────┬─┘ +/// │ │ │ │ │ │ +/// ▼ │ ▼ │ ▼ │ +/// NetworkBroadcast │ NetworkBroadcast │ NetworkBroadcast │ +/// (full copy) │ (full copy) │ (full copy) │ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ ▼ │ ▼ │ ▼ +/// │ Large table │ Large table │ Large table +/// │ partition 1 │ partition 2 │ partition 3 +/// │ │ │ │ │ │ +/// └─────────────────┼────────┴─────────────────┼────────┴─────────────────┘ +/// │ │ +/// Stage N: Small table collected + Large table partitioned +/// │ │ +/// ┌───────▼──────┐ ┌────────▼────────┐ +/// │ Small table │ │ Large table │ +/// │ (1 task, │ │ (3 partitions) │ +/// │ collected) │ └─────────────────┘ +/// └──────────────┘ +/// +/// The communication between two stages across a [NetworkBroadcastExec] has these characteristics: +/// +/// - The input stage typically has 1 task containing the collected/small dataset +/// - Each task in Stage N+1 receives a complete copy of the data from Stage N +/// - This enables parallel execution while ensuring all tasks have access to the full dataset +/// - Commonly used for broadcast joins where the build side is small enough to replicate +/// +/// This node has two variants: +/// 1. Pending: acts as a placeholder for the distributed optimization step to mark it as ready. +/// 2. Ready: runs within a distributed stage and queries the input stage over the network +/// using Arrow Flight, broadcasting the data to all tasks. 
+///
+/// [NetworkShuffleExec]: crate::execution_plans::NetworkShuffleExec
+#[derive(Debug, Clone)]
+pub enum NetworkBroadcastExec {
+    Pending(NetworkBroadcastPending),
+    Ready(NetworkBroadcastReady),
+}
+
+/// Placeholder version of the [NetworkBroadcastExec] node. It acts as a marker for the
+/// distributed optimization step, which will replace it with the appropriate
+/// [NetworkBroadcastReady] node.
+#[derive(Debug, Clone)]
+pub struct NetworkBroadcastPending {
+    properties: PlanProperties,
+    input_tasks: usize,
+    input: Arc<dyn ExecutionPlan>,
+}
+
+/// Ready version of the [NetworkBroadcastExec] node. This node is created by:
+/// - the distributed optimization step, based on an original [NetworkBroadcastPending]
+/// - deserialization of a protobuf plan sent over the network.
+///
+/// This variant contains the input [Stage] information and executes by broadcasting
+/// data from the input stage to all tasks in the current stage over Arrow Flight.
+#[derive(Debug, Clone)]
+pub struct NetworkBroadcastReady {
+    pub(crate) properties: PlanProperties,
+    pub(crate) input_stage: Stage,
+    pub(crate) metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
+}
+
+impl NetworkBroadcastExec {
+    pub fn new(input: Arc<dyn ExecutionPlan>, input_tasks: usize) -> Self {
+        Self::Pending(NetworkBroadcastPending {
+            properties: input.properties().clone(),
+            input_tasks,
+            input,
+        })
+    }
+}
+
+impl NetworkBoundary for NetworkBroadcastExec {
+    fn get_input_stage_info(
+        &self,
+        _n_tasks: usize,
+    ) -> datafusion::common::Result<InputStageInfo> {
+        let Self::Pending(pending) = self else {
+            return plan_err!("can only return the wrapped child while in the Pending state");
+        };
+
+        Ok(InputStageInfo {
+            plan: Arc::clone(&pending.input),
+            task_count: pending.input_tasks,
+        })
+    }
+
+    fn with_input_task_count(
+        &self,
+        input_tasks: usize,
+    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+        match self {
+            Self::Pending(pending) => Ok(Arc::new(Self::Pending(NetworkBroadcastPending {
+                properties: pending.properties.clone(),
+                input_tasks,
+                input: pending.input.clone(),
+            }))),
+            Self::Ready(_) => {
+                plan_err!("NetworkBroadcastExec can only re-assign input tasks in the 'Pending' state")
+            }
+        }
+    }
+
+    fn input_task_count(&self) -> usize {
+        match self {
+            Self::Pending(v) => v.input_tasks,
+            Self::Ready(v) => v.input_stage.tasks.len(),
+        }
+    }
+
+    fn with_input_stage(
+        &self,
+        input_stage: Stage,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        match self {
+            Self::Pending(pending) => {
+                let ready = NetworkBroadcastReady {
+                    properties: pending.properties.clone(),
+                    input_stage,
+                    metrics_collection: Default::default(),
+                };
+                Ok(Arc::new(Self::Ready(ready)))
+            }
+            Self::Ready(ready) => {
+                let mut ready = ready.clone();
+                ready.input_stage = input_stage;
+                Ok(Arc::new(Self::Ready(ready)))
+            }
+        }
+    }
+
+    fn input_stage(&self) -> Option<&Stage> {
+        match self {
+            Self::Pending(_) => None,
+            Self::Ready(v) => Some(&v.input_stage),
+        }
+    }
+}
+
+impl DisplayAs for NetworkBroadcastExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            NetworkBroadcastExec::Pending(_) => {
+                write!(f, "NetworkBroadcastExec: [Pending]")
+            }
+            NetworkBroadcastExec::Ready(ready) => {
+                write!(
+                    f,
+                    "NetworkBroadcastExec: [Stage {}] ({} tasks)",
+                    ready.input_stage.num,
+                    ready.input_stage.tasks.len()
+                )
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for NetworkBroadcastExec {
+    fn name(&self) -> &str {
+        "NetworkBroadcastExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        match self {
+            NetworkBroadcastExec::Pending(v) => v.input.properties(),
+            NetworkBroadcastExec::Ready(v) => &v.properties,
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        match self {
+            NetworkBroadcastExec::Pending(v) => vec![&v.input],
+            NetworkBroadcastExec::Ready(v) => match &v.input_stage.plan {
+                MaybeEncodedPlan::Decoded(v) => vec![v],
+                MaybeEncodedPlan::Encoded(_) => vec![],
+            },
+        }
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        match self.as_ref() {
+            Self::Pending(v) => {
+                let mut v = v.clone();
+                v.input = require_one_child(children)?;
+                Ok(Arc::new(Self::Pending(v)))
+            }
+            Self::Ready(v) => {
+                let mut v = v.clone();
+                v.input_stage.plan = MaybeEncodedPlan::Decoded(require_one_child(children)?);
+                Ok(Arc::new(Self::Ready(v)))
+            }
+        }
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream, DataFusionError> {
+        let NetworkBroadcastExec::Ready(self_ready) = self else {
+            return exec_err!(
+                "NetworkBroadcastExec is not ready, was the distributed optimization step performed?"
+            );
+        };
+
+        let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
+        let input_stage = &self_ready.input_stage;
+        let encoded_input_plan = input_stage.plan.encoded()?;
+        let input_stage_tasks = input_stage.tasks.to_vec();
+        let input_task_count = input_stage_tasks.len();
+        let input_stage_num = input_stage.num as u64;
+        let query_id = Bytes::from(input_stage.query_id.as_bytes().to_vec());
+        let context_headers = ContextGrpcMetadata::headers_from_ctx(&context);
+
+        let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
+            let channel_resolver = Arc::clone(&channel_resolver);
+            let ticket = Request::from_parts(
+                MetadataMap::from_headers(context_headers.clone()),
+                Extensions::default(),
+                Ticket {
+                    ticket: DoGet {
+                        plan_proto: encoded_input_plan.clone(),
+                        target_partition: partition as u64,
+                        stage_key: Some(StageKey::new(query_id.clone(), input_stage_num, i as u64)),
+                        target_task_index: i as u64,
+                        target_task_count: input_task_count as u64,
+                    }
+                    .encode_to_vec()
+                    .into(),
+                },
+            );
+
+            let metrics_collection_capture = self_ready.metrics_collection.clone();
+            async move {
+                let url = task.url.ok_or(internal_datafusion_err!(
+                    "NetworkBroadcastExec: task is unassigned, cannot proceed"
+                ))?;
+                let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
+                let stream = client
+                    .do_get(ticket)
+                    .await
+                    .map_err(map_status_to_datafusion_error)?
+ .into_inner() + .map_err(|err| FlightError::Tonic(Box::new(err))); + let metrics_collecting_stream = + MetricsCollectingStream::new(stream, metrics_collection_capture); + Ok( + FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream) + .map_err(map_flight_to_datafusion_error), + ) + } + .try_flatten_stream() + .boxed() + }); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::select_all(stream), + ))) + } +} diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs index ee4fb04..f18f2cd 100644 --- a/src/protobuf/distributed_codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -1,6 +1,9 @@ use super::get_distributed_user_codecs; use crate::NetworkBoundary; -use crate::execution_plans::{NetworkCoalesceExec, NetworkCoalesceReady, NetworkShuffleReadyExec}; +use crate::execution_plans::{ + NetworkBroadcastExec, NetworkBroadcastReady, NetworkCoalesceExec, NetworkCoalesceReady, + NetworkShuffleReadyExec, +}; use crate::stage::{ExecutionTask, MaybeEncodedPlan, Stage}; use crate::{NetworkShuffleExec, PartitionIsolatorExec}; use bytes::Bytes; @@ -136,6 +139,30 @@ impl PhysicalExtensionCodec for DistributedCodec { parse_stage_proto(input_stage, inputs)?, ))) } + DistributedExecNode::NetworkBroadcast(NetworkBroadcastExecProto { + schema, + partitioning, + input_stage, + }) => { + let schema: Schema = schema + .as_ref() + .map(|s| s.try_into()) + .ok_or(proto_error("NetworkBroadcastExec is missing schema"))??; + + let partitioning = parse_protobuf_partitioning( + partitioning.as_ref(), + ctx, + &schema, + &DistributedCodec {}, + )? + .ok_or(proto_error("NetworkBroadcastExec is missing partitioning"))?; + + Ok(Arc::new(new_network_broadcast_exec( + partitioning, + Arc::new(schema), + parse_stage_proto(input_stage, inputs)?, + ))) + } DistributedExecNode::PartitionIsolator(PartitionIsolatorExecProto { n_tasks }) => { if inputs.len() != 1 { return Err(proto_error(format!( @@ -203,6 +230,21 @@ impl PhysicalExtensionCodec for DistributedCodec { node: Some(DistributedExecNode::NetworkCoalesceTasks(inner)), }; + wrapper.encode(buf).map_err(|e| proto_error(format!("{e}"))) + } else if let Some(node) = node.as_any().downcast_ref::() { + let inner = NetworkBroadcastExecProto { + schema: Some(node.schema().try_into()?), + partitioning: Some(serialize_partitioning( + node.properties().output_partitioning(), + &DistributedCodec {}, + )?), + input_stage: Some(encode_stage_proto(node.input_stage())?), + }; + + let wrapper = DistributedExecProto { + node: Some(DistributedExecNode::NetworkBroadcast(inner)), + }; + wrapper.encode(buf).map_err(|e| proto_error(format!("{e}"))) } else if let Some(node) = node.as_any().downcast_ref::() { let PartitionIsolatorExec::Ready(ready_node) = node else { @@ -289,6 +331,8 @@ pub enum DistributedExecNode { NetworkCoalesceTasks(NetworkCoalesceExecProto), #[prost(message, tag = "3")] PartitionIsolator(PartitionIsolatorExecProto), + #[prost(message, tag = "4")] + NetworkBroadcast(NetworkBroadcastExecProto), } #[derive(Clone, PartialEq, ::prost::Message)] @@ -357,6 +401,36 @@ fn new_network_coalesce_tasks_exec( }) } +/// Protobuf representation of the [NetworkBroadcastExec] physical node. It serves as +/// an intermediate format for serializing/deserializing [NetworkBroadcastExec] nodes +/// to send them over the wire. 
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct NetworkBroadcastExecProto { + #[prost(message, optional, tag = "1")] + schema: Option, + #[prost(message, optional, tag = "2")] + partitioning: Option, + #[prost(message, optional, tag = "3")] + input_stage: Option, +} + +fn new_network_broadcast_exec( + partitioning: Partitioning, + schema: SchemaRef, + input_stage: Stage, +) -> NetworkBroadcastExec { + NetworkBroadcastExec::Ready(NetworkBroadcastReady { + properties: PlanProperties::new( + EquivalenceProperties::new(schema), + partitioning, + EmissionType::Incremental, + Boundedness::Bounded, + ), + input_stage, + metrics_collection: Default::default(), + }) +} + fn encode_tasks(tasks: &[ExecutionTask]) -> Vec { tasks .iter() @@ -600,4 +674,23 @@ mod tests { Ok(()) } + + #[test] + fn test_roundtrip_single_broadcast() -> datafusion::common::Result<()> { + let codec = DistributedCodec; + let ctx = create_context(); + + let schema = schema_i32("h"); + let part = Partitioning::Hash(vec![Arc::new(Column::new("h", 0))], 4); + let plan: Arc = + Arc::new(new_network_broadcast_exec(part, schema, dummy_stage())); + + let mut buf = Vec::new(); + codec.try_encode(plan.clone(), &mut buf)?; + + let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?; + assert_eq!(repr(&plan), repr(&decoded)); + + Ok(()) + } } diff --git a/tests/broadcast_join_test.rs b/tests/broadcast_join_test.rs new file mode 100644 index 0000000..adf20c3 --- /dev/null +++ b/tests/broadcast_join_test.rs @@ -0,0 +1,328 @@ +#[cfg(all(feature = "integration", test))] +mod tests { + use datafusion::arrow::array::{Int32Array, StringArray}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::arrow::util::pretty::pretty_format_batches; + use datafusion::physical_plan::execute_stream; + use datafusion::prelude::SessionContext; + use datafusion_distributed::test_utils::localhost::start_localhost_context; + use datafusion_distributed::{DefaultSessionBuilder, display_plan_ascii}; + use futures::TryStreamExt; + use parquet::arrow::ArrowWriter; + use std::error::Error; + use std::sync::Arc; + use uuid::Uuid; + + /// Helper function to create a small dimension table (100 rows) across multiple files + async fn create_small_table(ctx: &SessionContext) -> Result<(), Box> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("dim_value", DataType::Utf8, false), + ])); + + // Create temp directory for partitioned data + use std::fs; + let temp_dir = std::env::temp_dir(); + let table_dir = temp_dir.join(format!("small_table_{}", Uuid::new_v4())); + fs::create_dir(&table_dir)?; + + // Create 100 rows split across 1 file (small table should be in one partition) + let mut id_values = Vec::new(); + let mut dim_values = Vec::new(); + for i in 0..100 { + id_values.push(i); + dim_values.push(format!("dim_{}", i % 10)); + } + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(id_values)), + Arc::new(StringArray::from(dim_values)), + ], + )?; + + // Write to parquet file + let file_path = table_dir.join("part-0.parquet"); + let file = std::fs::File::create(&file_path)?; + let mut writer = ArrowWriter::try_new(file, schema.clone(), None)?; + writer.write(&batch)?; + writer.close()?; + + // Register as parquet table with directory + ctx.register_parquet( + "small_table", + table_dir.to_str().unwrap(), + datafusion::prelude::ParquetReadOptions::default(), + ) + .await?; + + Ok(()) + } + 
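One possible tweak to the helpers above and below: the parquet files are written into uniquely named directories under the OS temp dir and never removed, so repeated test runs accumulate data. A sketch of a variant using the `tempfile` crate (assumed as a dev-dependency; not part of this patch) that cleans up automatically:

```rust
// Sketch only: relies on the `tempfile` crate, which this patch does not add.
use std::error::Error;
use tempfile::TempDir;

/// Creates a scratch directory that is deleted, together with any parquet files
/// written into it, once the returned `TempDir` guard is dropped.
fn scratch_table_dir(prefix: &str) -> Result<TempDir, Box<dyn Error>> {
    Ok(tempfile::Builder::new().prefix(prefix).tempdir()?)
}
```

The helper would then return the guard alongside `Ok(())` so the caller can keep it alive for as long as the table stays registered.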
+ /// Helper function to create a large fact table (10,000 rows) across multiple files + async fn create_large_table(ctx: &SessionContext) -> Result<(), Box> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("fact_value", DataType::Int32, false), + ])); + + // Create temp directory for partitioned data + use std::fs; + let temp_dir = std::env::temp_dir(); + let table_dir = temp_dir.join(format!("large_table_{}", Uuid::new_v4())); + fs::create_dir(&table_dir)?; + + // Create 10,000 rows split across 3 files (to match the 3 workers) + let rows_per_file = 3334; + for file_num in 0..3 { + let mut id_values = Vec::new(); + let mut fact_values = Vec::new(); + + let start_row = file_num * rows_per_file; + let end_row = if file_num == 2 { + 10000 // Last file gets remaining rows + } else { + start_row + rows_per_file + }; + + for i in start_row..end_row { + let id = i % 100; // Cycle through 0-99 to match small table + id_values.push(id); + fact_values.push(i); + } + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(id_values)), + Arc::new(Int32Array::from(fact_values)), + ], + )?; + + // Write to parquet file + let file_path = table_dir.join(format!("part-{}.parquet", file_num)); + let file = std::fs::File::create(&file_path)?; + let mut writer = ArrowWriter::try_new(file, schema.clone(), None)?; + writer.write(&batch)?; + writer.close()?; + } + + // Register as parquet table with directory + ctx.register_parquet( + "large_table", + table_dir.to_str().unwrap(), + datafusion::prelude::ParquetReadOptions::default(), + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_broadcast_join_basic() -> Result<(), Box> { + let (ctx_distributed, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; + + // Setup tables + create_small_table(&ctx_distributed).await?; + create_large_table(&ctx_distributed).await?; + + // Test query: join large table with small table + // Use aggregation to force distributed execution + let query = r#" + SELECT COUNT(*) as total + FROM large_table AS l + JOIN small_table AS s + ON l.id = s.id + "#; + + let df = ctx_distributed.sql(query).await?; + let physical = df.create_physical_plan().await?; + let physical_str = display_plan_ascii(physical.as_ref(), false); + + println!("Physical plan:\n{}", physical_str); + + // Verify that the plan contains NetworkBroadcastExec + assert!( + physical_str.contains("NetworkBroadcast"), + "Expected NetworkBroadcastExec in plan, got:\n{}", + physical_str + ); + + let results = execute_stream(physical, ctx_distributed.task_ctx())? 
+ .try_collect::>() + .await?; + + assert!(!results.is_empty(), "Expected non-empty results"); + + // Verify result + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!( + total_rows, 1, + "Expected 1 row for COUNT(*), got {}", + total_rows + ); + + Ok(()) + } + + #[tokio::test] + async fn test_broadcast_join_uses_multiple_tasks() -> Result<(), Box> { + let (ctx_distributed, _guard) = start_localhost_context(4, DefaultSessionBuilder).await; + + // Setup tables + create_small_table(&ctx_distributed).await?; + create_large_table(&ctx_distributed).await?; + + let query = r#" + SELECT COUNT(*) as total + FROM large_table AS l + JOIN small_table AS s + ON l.id = s.id + "#; + + let df = ctx_distributed.sql(query).await?; + let physical = df.create_physical_plan().await?; + let physical_str = display_plan_ascii(physical.as_ref(), false); + + // Verify that the plan contains NetworkBroadcastExec + assert!( + physical_str.contains("NetworkBroadcast"), + "Expected NetworkBroadcastExec in plan" + ); + + // Check that there are multiple tasks being used (not forced to 1) + // Stage 2 (where the join happens) should have multiple tasks like "t0:[p0] t1:[p1] t2:[p2]" + assert!( + physical_str.contains("Stage 2") && physical_str.contains("t1:"), + "Expected multiple tasks in Stage 2 for distributed execution, plan:\n{}", + physical_str + ); + + // Execute query + let results = execute_stream(physical, ctx_distributed.task_ctx())? + .try_collect::>() + .await?; + + assert_eq!( + results.len(), + 1, + "Expected single result batch for COUNT(*)" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_broadcast_join_correctness() -> Result<(), Box> { + // Setup distributed context with 3 workers + let (ctx_distributed, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; + + // Setup single-node context + let ctx_single = SessionContext::default(); + *ctx_single.state_ref().write().config_mut() = ctx_distributed.copied_config(); + + // Create tables in both contexts + create_small_table(&ctx_distributed).await?; + create_large_table(&ctx_distributed).await?; + create_small_table(&ctx_single).await?; + create_large_table(&ctx_single).await?; + + let query = r#" + SELECT l.id, COUNT(*) as cnt, SUM(l.fact_value) as total + FROM large_table AS l + JOIN small_table AS s + ON l.id = s.id + GROUP BY l.id + ORDER BY l.id + "#; + + // Execute on single-node + let df_single = ctx_single.sql(query).await?; + let physical_single = df_single.create_physical_plan().await?; + let results_single = execute_stream(physical_single, ctx_single.task_ctx())? + .try_collect::>() + .await?; + let single_output = pretty_format_batches(&results_single)?; + + // Execute on distributed + let df_distributed = ctx_distributed.sql(query).await?; + let physical_distributed = df_distributed.create_physical_plan().await?; + let physical_str = display_plan_ascii(physical_distributed.as_ref(), false); + + // Verify NetworkBroadcastExec is in the plan + assert!( + physical_str.contains("NetworkBroadcast"), + "Expected NetworkBroadcastExec in distributed plan" + ); + + let results_distributed = execute_stream(physical_distributed, ctx_distributed.task_ctx())? 
+ .try_collect::>() + .await?; + let distributed_output = pretty_format_batches(&results_distributed)?; + + // Compare results - they should be identical + assert_eq!( + single_output.to_string(), + distributed_output.to_string(), + "Distributed and single-node results should match.\nSingle:\n{}\nDistributed:\n{}", + single_output, + distributed_output + ); + + // Verify we got results for all 100 IDs + let total_rows: usize = results_distributed + .iter() + .map(|batch| batch.num_rows()) + .sum(); + assert_eq!( + total_rows, 100, + "Expected 100 rows (one per ID), got {}", + total_rows + ); + + Ok(()) + } + + #[tokio::test] + async fn test_broadcast_join_with_filter() -> Result<(), Box> { + let (ctx_distributed, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; + + create_small_table(&ctx_distributed).await?; + create_large_table(&ctx_distributed).await?; + + // Test with WHERE clause to ensure filtering works correctly with broadcast + let query = r#" + SELECT l.id, COUNT(*) as cnt + FROM large_table AS l + JOIN small_table AS s + ON l.id = s.id + WHERE l.fact_value < 5000 + GROUP BY l.id + ORDER BY l.id + "#; + + let df = ctx_distributed.sql(query).await?; + let physical = df.create_physical_plan().await?; + + let results = execute_stream(physical, ctx_distributed.task_ctx())? + .try_collect::>() + .await?; + + assert!( + !results.is_empty(), + "Expected non-empty results with filter" + ); + + // Verify that we only get results for filtered data + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert!( + total_rows > 0 && total_rows <= 100, + "Expected reasonable number of rows after filtering, got {}", + total_rows + ); + + Ok(()) + } +}
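As a possible follow-up to `test_broadcast_join_basic`: with `id = i % 100` in the fact table and exactly one dimension row per id, all 10,000 fact rows find a match, so the single `COUNT(*)` value should be exactly 10,000, not just one output row. A sketch of that extra check (the helper name is illustrative and not part of this patch):

```rust
use datafusion::arrow::array::{Array, Int64Array};
use datafusion::arrow::record_batch::RecordBatch;

/// Sketch: asserts that a collected COUNT(*) result equals `expected`.
fn assert_count_equals(results: &[RecordBatch], expected: i64) {
    // COUNT(*) comes back as a single Int64 value in the one non-empty batch.
    let batch = results
        .iter()
        .find(|b| b.num_rows() > 0)
        .expect("at least one non-empty batch");
    let counts = batch
        .column(0)
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("COUNT(*) is an Int64 column");
    assert_eq!(counts.value(0), expected);
}

// In the test body, after collecting `results`:
// assert_count_equals(&results, 10_000);
```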