diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index ed9646b..1631f2a 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -79,7 +79,3 @@ impl ArrowFlightEndpoint { )))) } } - -fn invalid_argument(msg: impl Into) -> Result { - Err(Status::invalid_argument(msg)) -} diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 56f1a2f..54c460e 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -1,9 +1,12 @@ use std::sync::Arc; +use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec}; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::error::DataFusionError; use datafusion::{ common::{ internal_datafusion_err, - tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TreeNode}, }, config::ConfigOptions, error::Result, @@ -14,8 +17,6 @@ use datafusion::{ }; use datafusion_proto::physical_plan::PhysicalExtensionCodec; -use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec}; - use super::stage::ExecutionStage; #[derive(Debug, Default)] @@ -75,11 +76,9 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { displayable(plan.as_ref()).indent(false) ); - let mut planner = StagePlanner::new(self.codec.clone(), self.partitions_per_task); - plan.rewrite(&mut planner)?; - planner - .finish() - .map(|stage| stage as Arc) + let plan = self.apply_network_boundaries(plan)?; + let plan = self.distribute_plan(plan)?; + Ok(Arc::new(plan)) } fn name(&self) -> &str { @@ -91,171 +90,78 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { } } -/// StagePlanner is a TreeNodeRewriter that walks the plan tree and creates -/// a tree of ExecutionStage nodes that represent discrete stages of execution -/// can are separated by a data shuffle. -/// -/// See https://howqueryengineswork.com/13-distributed-query.html for more information -/// about distributed execution. -struct StagePlanner { - /// used to keep track of the current plan head - plan_head: Option>, - /// Current depth in the plan tree, as we walk the tree - depth: usize, - /// Input stages collected so far. Each entry is a tuple of (plan tree depth, stage). - /// This allows us to keep track of the depth in the plan tree - /// where we created the stage. That way when we create a new - /// stage, we can tell if it is a peer to the current input stages or - /// should be a parent (if its depth is a smaller number) - input_stages: Vec<(usize, ExecutionStage)>, - /// current stage number - stage_counter: usize, - /// Optional codec to assist in serializing and deserializing any custom - codec: Option>, - /// partitions_per_task is used to determine how many tasks to create for each stage - partitions_per_task: Option, -} - -impl StagePlanner { - fn new( - codec: Option>, - partitions_per_task: Option, - ) -> Self { - StagePlanner { - plan_head: None, - depth: 0, - input_stages: vec![], - stage_counter: 1, - codec, - partitions_per_task, - } - } - - fn finish(mut self) -> Result> { - let stage = if self.input_stages.is_empty() { - ExecutionStage::new( - self.stage_counter, - self.plan_head - .take() - .ok_or_else(|| internal_datafusion_err!("No plan head set"))?, - vec![], - ) - } else if self.depth < self.input_stages[0].0 { - // There is more plan above the last stage we created, so we need to - // create a new stage that includes the last plan head - ExecutionStage::new( - self.stage_counter, - self.plan_head - .take() - .ok_or_else(|| internal_datafusion_err!("No plan head set"))?, - self.input_stages - .into_iter() - .map(|(_, stage)| Arc::new(stage)) - .collect(), - ) - } else { - // We have a plan head, and we are at the same depth as the last stage we created, - // so we can just return the last stage - self.input_stages.last().unwrap().1.clone() - }; - - // assign the proper tree depth to each stage in the tree - fn assign_tree_depth(stage: &ExecutionStage, depth: usize) { - stage - .depth - .store(depth as u64, std::sync::atomic::Ordering::Relaxed); - for input in stage.child_stages_iter() { - assign_tree_depth(input, depth + 1); +impl DistributedPhysicalOptimizerRule { + pub fn apply_network_boundaries( + &self, + plan: Arc, + ) -> Result, DataFusionError> { + let result = plan.transform_up(|plan| { + if plan.as_any().downcast_ref::().is_some() { + let child = Arc::clone(plan.children().first().cloned().ok_or( + internal_datafusion_err!("Expected RepartitionExec to have a child"), + )?); + + let maybe_isolated_plan = if let Some(ppt) = self.partitions_per_task { + let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt)); + plan.with_new_children(vec![isolated])? + } else { + plan + }; + + return Ok(Transformed::yes(Arc::new( + ArrowFlightReadExec::new_pending( + Arc::clone(&maybe_isolated_plan), + maybe_isolated_plan.output_partitioning().clone(), + ), + ))); } - } - assign_tree_depth(&stage, 0); - Ok(Arc::new(stage)) + Ok(Transformed::no(plan)) + })?; + Ok(result.data) } -} - -impl TreeNodeRewriter for StagePlanner { - type Node = Arc; - fn f_down(&mut self, plan: Self::Node) -> Result> { - self.depth += 1; - Ok(Transformed::no(plan)) + pub fn distribute_plan( + &self, + plan: Arc, + ) -> Result { + self._distribute_plan_inner(plan, &mut 1, 0) } - fn f_up(&mut self, plan: Self::Node) -> Result> { - self.depth -= 1; - - // keep track of where we are - self.plan_head = Some(plan.clone()); - - // determine if we need to shuffle data, and thus create a new stage - // at this shuffle boundary - if let Some(repartition_exec) = plan.as_any().downcast_ref::() { - // time to create a stage here so include all previous seen stages deeper than us as - // our input stages - let child_stages = self - .input_stages - .iter() - .rev() - .take_while(|(depth, _)| *depth > self.depth) - .map(|(_, stage)| stage.clone()) - .collect::>(); - - self.input_stages.retain(|(depth, _)| *depth <= self.depth); - - let maybe_isolated_plan = if let Some(partitions_per_task) = self.partitions_per_task { - let child = repartition_exec - .children() - .first() - .ok_or(internal_datafusion_err!( - "RepartitionExec has no children, cannot create PartitionIsolatorExec" - ))? - .clone() - .clone(); // just clone the Arcs - let isolated = Arc::new(PartitionIsolatorExec::new(child, partitions_per_task)); - plan.clone().with_new_children(vec![isolated])? - } else { - plan.clone() + fn _distribute_plan_inner( + &self, + plan: Arc, + num: &mut usize, + depth: usize, + ) -> Result { + let mut inputs = vec![]; + + let distributed = plan.transform_down(|plan| { + let Some(node) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); }; - - let mut stage = ExecutionStage::new( - self.stage_counter, - maybe_isolated_plan, - child_stages.into_iter().map(Arc::new).collect(), - ); - - if let Some(partitions_per_task) = self.partitions_per_task { - stage = stage.with_maximum_partitions_per_task(partitions_per_task); - } - if let Some(codec) = self.codec.as_ref() { - stage = stage.with_codec(codec.clone()); - } - - self.input_stages.push((self.depth, stage)); - - // As we are walking up the plan tree, we've now put what we've encountered so far - // into a stage. We want to replace this plan now with an ArrowFlightReadExec - // which will be able to consume from this stage over the network. - // - // That way as we walk further up the tree and build the next stage, the leaf - // node in that plan will be an ArrowFlightReadExec that can read from - // - // Note that we use the original plans partitioning and schema for ArrowFlightReadExec. - // If we divide it up in to tasks, then that parittion will need to be gathered from - // among them - let name = format!("Stage {:<3}", self.stage_counter); - let read = Arc::new(ArrowFlightReadExec::new( - plan.output_partitioning().clone(), - plan.schema(), - self.stage_counter, - )); - - self.stage_counter += 1; - - Ok(Transformed::yes(read as Self::Node)) - } else { - Ok(Transformed::no(plan)) + let child = Arc::clone(node.children().first().cloned().ok_or( + internal_datafusion_err!("Expected ArrowFlightExecRead to have a child"), + )?); + let stage = self._distribute_plan_inner(child, num, depth + 1)?; + let node = Arc::new(node.to_distributed(stage.num)?); + inputs.push(stage); + Ok(Transformed::new(node, true, TreeNodeRecursion::Jump)) + })?; + + let inputs = inputs.into_iter().map(Arc::new).collect(); + let mut stage = ExecutionStage::new(*num, distributed.data, inputs); + *num += 1; + + if let Some(partitions_per_task) = self.partitions_per_task { + stage = stage.with_maximum_partitions_per_task(partitions_per_task); + } + if let Some(codec) = self.codec.as_ref() { + stage = stage.with_codec(codec.clone()); } + stage.depth = depth; + + Ok(stage) } } @@ -427,6 +333,20 @@ mod tests { │partitions [out:4 ] ArrowFlightReadExec: Stage 4 │ └────────────────────────────────────────────────── + ┌───── Stage 2 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4 + │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MinTemp)] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = yes, projection=[MinTemp@0, RainTomorrow@2] + │partitions [out:4 ] ArrowFlightReadExec: Stage 1 + │ + └────────────────────────────────────────────────── + ┌───── Stage 1 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = yes, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= yes AND yes <= RainToday_max@1, required_guarantees=[RainToday in (yes)] + │ + │ + └────────────────────────────────────────────────── ┌───── Stage 4 Task: partitions: 0..3,unassigned] │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4 │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MaxTemp)] @@ -441,20 +361,6 @@ mod tests { │ │ └────────────────────────────────────────────────── - ┌───── Stage 2 Task: partitions: 0..3,unassigned] - │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4 - │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MinTemp)] - │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 - │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = yes, projection=[MinTemp@0, RainTomorrow@2] - │partitions [out:4 ] ArrowFlightReadExec: Stage 1 - │ - └────────────────────────────────────────────────── - ┌───── Stage 1 Task: partitions: 0..3,unassigned] - │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 - │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = yes, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= yes AND yes <= RainToday_max@1, required_guarantees=[RainToday in (yes)] - │ - │ - └────────────────────────────────────────────────── "); } diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index ae15fc8..73cb070 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -8,8 +8,8 @@ use arrow_flight::error::FlightError; use arrow_flight::flight_service_client::FlightServiceClient; use arrow_flight::Ticket; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::common::{internal_datafusion_err, plan_err}; -use datafusion::error::{DataFusionError, Result}; +use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err}; +use datafusion::error::DataFusionError; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -22,32 +22,86 @@ use std::fmt::Formatter; use std::sync::Arc; use url::Url; +/// This node has two variants. +/// 1. Pending: it acts as a placeholder for the distributed optimization step to mark it as ready. +/// 2. Ready: runs within a distributed stage and queries the next input stage over the network +/// using Arrow Flight. #[derive(Debug, Clone)] -pub struct ArrowFlightReadExec { - /// the number of the stage we are reading from - pub stage_num: usize, +pub enum ArrowFlightReadExec { + Pending(ArrowFlightReadPendingExec), + Ready(ArrowFlightReadReadyExec), +} + +/// Placeholder version of the [ArrowFlightReadExec] node. It acts as a marker for the +/// distributed optimization step, which will replace it with the appropriate +/// [ArrowFlightReadReadyExec] node. +#[derive(Debug, Clone)] +pub struct ArrowFlightReadPendingExec { + properties: PlanProperties, + child: Arc, +} + +/// Ready version of the [ArrowFlightReadExec] node. This node can be created in +/// just two ways: +/// - by the distributed optimization step based on an original [ArrowFlightReadPendingExec] +/// - deserialized from a protobuf plan sent over the network. +#[derive(Debug, Clone)] +pub struct ArrowFlightReadReadyExec { /// the properties we advertise for this execution plan properties: PlanProperties, + pub(crate) stage_num: usize, } impl ArrowFlightReadExec { - pub fn new(partitioning: Partitioning, schema: SchemaRef, stage_num: usize) -> Self { + pub fn new_pending(child: Arc, partitioning: Partitioning) -> Self { + Self::Pending(ArrowFlightReadPendingExec { + properties: PlanProperties::new( + EquivalenceProperties::new(child.schema()), + partitioning, + EmissionType::Incremental, + Boundedness::Bounded, + ), + child, + }) + } + + pub(crate) fn new_ready( + partitioning: Partitioning, + schema: SchemaRef, + stage_num: usize, + ) -> Self { let properties = PlanProperties::new( EquivalenceProperties::new(schema), partitioning, EmissionType::Incremental, Boundedness::Bounded, ); - Self { + Self::Ready(ArrowFlightReadReadyExec { properties, stage_num, + }) + } + + pub(crate) fn to_distributed(&self, stage_num: usize) -> Result { + match self { + ArrowFlightReadExec::Pending(p) => Ok(Self::new_ready( + p.properties.partitioning.clone(), + p.child.schema(), + stage_num, + )), + _ => internal_err!("ArrowFlightReadExec is already distributed"), } } } impl DisplayAs for ArrowFlightReadExec { fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, "ArrowFlightReadExec: Stage {:<3}", self.stage_num) + match self { + ArrowFlightReadExec::Pending(_) => write!(f, "ArrowFlightReadExec"), + ArrowFlightReadExec::Ready(v) => { + write!(f, "ArrowFlightReadExec: Stage {:<3}", v.stage_num) + } + } } } @@ -61,17 +115,23 @@ impl ExecutionPlan for ArrowFlightReadExec { } fn properties(&self) -> &PlanProperties { - &self.properties + match self { + ArrowFlightReadExec::Pending(v) => &v.properties, + ArrowFlightReadExec::Ready(v) => &v.properties, + } } fn children(&self) -> Vec<&Arc> { - vec![] + match self { + ArrowFlightReadExec::Pending(v) => vec![&v.child], + ArrowFlightReadExec::Ready(_) => vec![], + } } fn with_new_children( self: Arc, children: Vec>, - ) -> datafusion::common::Result> { + ) -> Result, DataFusionError> { if !children.is_empty() { return plan_err!( "ArrowFlightReadExec: wrong number of children, expected 0, got {}", @@ -85,9 +145,14 @@ impl ExecutionPlan for ArrowFlightReadExec { &self, partition: usize, context: Arc, - ) -> datafusion::common::Result { - /// get the channel manager and current stage from our context + ) -> Result { + let ArrowFlightReadExec::Ready(this) = self else { + return exec_err!("ArrowFlightReadExec is not ready, was the distributed optimization step performed?"); + }; + + // get the channel manager and current stage from our context let channel_manager: ChannelManager = context.as_ref().try_into()?; + let stage = context .session_config() .get_extension::() @@ -99,10 +164,10 @@ impl ExecutionPlan for ArrowFlightReadExec { // reading from let child_stage = stage .child_stages_iter() - .find(|s| s.num == self.stage_num) + .find(|s| s.num == this.stage_num) .ok_or(internal_datafusion_err!( "ArrowFlightReadExec: no child stage with num {}", - self.stage_num + this.stage_num ))?; let child_stage_tasks = child_stage.tasks.clone(); @@ -127,7 +192,7 @@ impl ExecutionPlan for ArrowFlightReadExec { let schema = child_stage.plan.schema(); let stream = async move { - let futs = child_stage_tasks.iter().map(|task| async { + let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| async { let url = task.url()?.ok_or(internal_datafusion_err!( "ArrowFlightReadExec: task is unassigned, cannot proceed" ))?; diff --git a/src/plan/codec.rs b/src/plan/codec.rs index f4b3871..1bc2cf4 100644 --- a/src/plan/codec.rs +++ b/src/plan/codec.rs @@ -52,7 +52,7 @@ impl PhysicalExtensionCodec for DistributedCodec { )? .ok_or(proto_error("ArrowFlightReadExec is missing partitioning"))?; - Ok(Arc::new(ArrowFlightReadExec::new( + Ok(Arc::new(ArrowFlightReadExec::new_ready( partioning, Arc::new(schema), stage_num as usize, @@ -84,13 +84,18 @@ impl PhysicalExtensionCodec for DistributedCodec { buf: &mut Vec, ) -> datafusion::common::Result<()> { if let Some(node) = node.as_any().downcast_ref::() { + let ArrowFlightReadExec::Ready(ready_node) = node else { + return Err(proto_error( + "deserialized an ArrowFlightReadExec that is not ready", + )); + }; let inner = ArrowFlightReadExecProto { schema: Some(node.schema().try_into()?), partitioning: Some(serialize_partitioning( node.properties().output_partitioning(), &DistributedCodec {}, )?), - stage_num: node.stage_num as u64, + stage_num: ready_node.stage_num as u64, }; let wrapper = DistributedExecProto { @@ -172,7 +177,8 @@ mod tests { let schema = schema_i32("a"); let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4); - let plan: Arc = Arc::new(ArrowFlightReadExec::new(part, schema, 0)); + let plan: Arc = + Arc::new(ArrowFlightReadExec::new_ready(part, schema, 0)); let mut buf = Vec::new(); codec.try_encode(plan.clone(), &mut buf)?; @@ -189,7 +195,7 @@ mod tests { let registry = MemoryFunctionRegistry::new(); let schema = schema_i32("b"); - let flight = Arc::new(ArrowFlightReadExec::new( + let flight = Arc::new(ArrowFlightReadExec::new_ready( Partitioning::UnknownPartitioning(1), schema, 0, @@ -212,12 +218,12 @@ mod tests { let registry = MemoryFunctionRegistry::new(); let schema = schema_i32("c"); - let left = Arc::new(ArrowFlightReadExec::new( + let left = Arc::new(ArrowFlightReadExec::new_ready( Partitioning::RoundRobinBatch(2), schema.clone(), 0, )); - let right = Arc::new(ArrowFlightReadExec::new( + let right = Arc::new(ArrowFlightReadExec::new_ready( Partitioning::RoundRobinBatch(2), schema.clone(), 1, @@ -241,7 +247,7 @@ mod tests { let registry = MemoryFunctionRegistry::new(); let schema = schema_i32("d"); - let flight = Arc::new(ArrowFlightReadExec::new( + let flight = Arc::new(ArrowFlightReadExec::new_ready( Partitioning::UnknownPartitioning(1), schema.clone(), 0, diff --git a/src/stage/display.rs b/src/stage/display.rs index 6d5e3d6..f3b7642 100644 --- a/src/stage/display.rs +++ b/src/stage/display.rs @@ -54,15 +54,13 @@ impl DisplayAs for ExecutionStage { )?; let plan_str = display_plan_with_partition_in_out(self.plan.as_ref()) .map_err(|_| std::fmt::Error {})?; - let plan_str = plan_str.replace( - '\n', - &format!("\n{}{}", " ".repeat(self.depth()), VERTICAL), - ); - writeln!(f, "{}{}{}", " ".repeat(self.depth()), VERTICAL, plan_str)?; + let plan_str = + plan_str.replace('\n', &format!("\n{}{}", " ".repeat(self.depth), VERTICAL)); + writeln!(f, "{}{}{}", " ".repeat(self.depth), VERTICAL, plan_str)?; write!( f, "{}{}{}", - " ".repeat(self.depth()), + " ".repeat(self.depth), LDCORNER, HORIZONTAL.repeat(50) )?; diff --git a/src/stage/proto.rs b/src/stage/proto.rs index 7d0d0d8..f0b127f 100644 --- a/src/stage/proto.rs +++ b/src/stage/proto.rs @@ -94,7 +94,7 @@ pub fn stage_from_proto( inputs, tasks: msg.tasks, codec: Some(codec), - depth: std::sync::atomic::AtomicU64::new(0), + depth: 0, }) } @@ -159,7 +159,7 @@ mod tests { inputs: vec![], tasks: vec![], codec: Some(Arc::new(DefaultPhysicalExtensionCodec {})), - depth: std::sync::atomic::AtomicU64::new(0), + depth: 0, }; // Convert to proto message diff --git a/src/stage/stage.rs b/src/stage/stage.rs index c971c1f..b161f03 100644 --- a/src/stage/stage.rs +++ b/src/stage/stage.rs @@ -1,4 +1,3 @@ -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use datafusion::common::internal_err; @@ -20,7 +19,7 @@ use crate::ChannelManager; /// /// see https://howqueryengineswork.com/13-distributed-query.html /// -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ExecutionStage { /// Our stage number pub num: usize, @@ -36,23 +35,7 @@ pub struct ExecutionStage { /// An optional codec to assist in serializing and deserializing this stage pub codec: Option>, /// tree depth of our location in the stage tree, used for display only - pub(crate) depth: AtomicU64, -} - -impl Clone for ExecutionStage { - /// Creates a shallow clone of this `ExecutionStage`. The plan, input stages, - /// and codec are cloned Arcs as we dont need to duplicate the underlying data, - fn clone(&self) -> Self { - ExecutionStage { - num: self.num, - name: self.name.clone(), - plan: self.plan.clone(), - inputs: self.inputs.to_vec(), - tasks: self.tasks.clone(), - codec: self.codec.clone(), - depth: AtomicU64::new(self.depth.load(Ordering::Relaxed)), - } - } + pub depth: usize, } impl ExecutionStage { @@ -83,7 +66,7 @@ impl ExecutionStage { .collect(), tasks: vec![ExecutionTask::new(partition_group)], codec: None, - depth: AtomicU64::new(0), + depth: 0, } } @@ -192,15 +175,11 @@ impl ExecutionStage { inputs: assigned_children, tasks: assigned_tasks, codec: self.codec.clone(), - depth: AtomicU64::new(self.depth.load(Ordering::Relaxed)), + depth: self.depth, }; Ok(assigned_stage) } - - pub(crate) fn depth(&self) -> usize { - self.depth.load(Ordering::Relaxed) as usize - } } impl ExecutionPlan for ExecutionStage { @@ -227,7 +206,7 @@ impl ExecutionPlan for ExecutionStage { inputs: children, tasks: self.tasks.clone(), codec: self.codec.clone(), - depth: AtomicU64::new(self.depth.load(Ordering::Relaxed)), + depth: self.depth, })) } diff --git a/tests/common/localhost.rs b/tests/common/localhost.rs index 5944e29..4b2d314 100644 --- a/tests/common/localhost.rs +++ b/tests/common/localhost.rs @@ -4,7 +4,6 @@ use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::SessionContext; use datafusion::{common::runtime::JoinSet, prelude::SessionConfig}; -use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule; use datafusion_distributed::{ ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, SessionBuilder, }; @@ -50,12 +49,9 @@ where let config = SessionConfig::new().with_target_partitions(3); - let rule = DistributedPhysicalOptimizerRule::default().with_maximum_partitions_per_task(4); - let state = SessionStateBuilder::new() .with_default_features() .with_config(config) - .with_physical_optimizer_rule(Arc::new(rule)) .build(); let ctx = SessionContext::new_with_state(state); diff --git a/tests/common/plan.rs b/tests/common/plan.rs index 892bba5..df8560e 100644 --- a/tests/common/plan.rs +++ b/tests/common/plan.rs @@ -4,6 +4,7 @@ use datafusion::error::DataFusionError; use datafusion::physical_expr::Partitioning; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; use datafusion::physical_plan::ExecutionPlan; +use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule; use datafusion_distributed::ArrowFlightReadExec; use std::sync::Arc; @@ -11,58 +12,55 @@ pub fn distribute_aggregate( plan: Arc, ) -> Result, DataFusionError> { let mut aggregate_partial_found = false; - Ok(plan - .transform_up(|node| { - let Some(agg) = node.as_any().downcast_ref::() else { - return Ok(Transformed::no(node)); - }; + let transformed = plan.transform_up(|node| { + let Some(agg) = node.as_any().downcast_ref::() else { + return Ok(Transformed::no(node)); + }; - match agg.mode() { - AggregateMode::Partial => { - if aggregate_partial_found { - return plan_err!("Two consecutive partial aggregations found"); - } - aggregate_partial_found = true; - let expr = agg - .group_expr() - .expr() - .iter() - .map(|(v, _)| Arc::clone(v)) - .collect::>(); - - if node.children().len() != 1 { - return plan_err!("Aggregate must have exactly one child"); - } - let child = node.children()[0].clone(); + match agg.mode() { + AggregateMode::Partial => { + if aggregate_partial_found { + return plan_err!("Two consecutive partial aggregations found"); + } + aggregate_partial_found = true; + let expr = agg + .group_expr() + .expr() + .iter() + .map(|(v, _)| Arc::clone(v)) + .collect::>(); - let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new( - Partitioning::Hash(expr, 1), - child.schema(), - 0, - ))])?; - Ok(Transformed::yes(node)) + if node.children().len() != 1 { + return plan_err!("Aggregate must have exactly one child"); } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => { - if !aggregate_partial_found { - return plan_err!("No partial aggregate found before the final one"); - } + let child = node.children()[0].clone(); - if node.children().len() != 1 { - return plan_err!("Aggregate must have exactly one child"); - } - let child = node.children()[0].clone(); + let node = node.with_new_children(vec![Arc::new( + ArrowFlightReadExec::new_pending(child, Partitioning::Hash(expr, 1)), + )])?; + Ok(Transformed::yes(node)) + } + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned => { + if !aggregate_partial_found { + return plan_err!("No partial aggregate found before the final one"); + } - let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new( - Partitioning::RoundRobinBatch(8), - child.schema(), - 1, - ))])?; - Ok(Transformed::yes(node)) + if node.children().len() != 1 { + return plan_err!("Aggregate must have exactly one child"); } + let child = node.children()[0].clone(); + + let node = node.with_new_children(vec![Arc::new( + ArrowFlightReadExec::new_pending(child, Partitioning::RoundRobinBatch(8)), + )])?; + Ok(Transformed::yes(node)) } - })? - .data) + } + })?; + Ok(Arc::new( + DistributedPhysicalOptimizerRule::default().distribute_plan(transformed.data)?, + )) } diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index 7367616..012f72c 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -27,6 +27,7 @@ mod tests { use datafusion::physical_plan::{ displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; + use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule; use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; @@ -67,8 +68,10 @@ mod tests { "); let distributed_plan = build_plan(true)?; + let distributed_plan = + DistributedPhysicalOptimizerRule::default().distribute_plan(distributed_plan)?; - assert_snapshot!(displayable(distributed_plan.as_ref()).indent(true).to_string(), @r" + assert_snapshot!(displayable(&distributed_plan).indent(true).to_string(), @r" SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=10 ArrowFlightReadExec: input_tasks=10 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/] @@ -93,7 +96,7 @@ mod tests { +---------+ "); - let stream = execute_stream(distributed_plan, ctx.task_ctx())?; + let stream = execute_stream(Arc::new(distributed_plan), ctx.task_ctx())?; let batches_distributed = stream.try_collect::>().await?; assert_snapshot!(pretty_format_batches(&batches_distributed).unwrap(), @r" @@ -123,10 +126,9 @@ mod tests { )?); if distributed { - plan = Arc::new(ArrowFlightReadExec::new( + plan = Arc::new(ArrowFlightReadExec::new_pending( + plan.clone(), Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1), - plan.clone().schema(), - 0, // TODO: stage num should be assigned by someone else )); } @@ -139,10 +141,9 @@ mod tests { )); if distributed { - plan = Arc::new(ArrowFlightReadExec::new( + plan = Arc::new(ArrowFlightReadExec::new_pending( + plan.clone(), Partitioning::RoundRobinBatch(10), - plan.clone().schema(), - 1, // TODO: stage num should be assigned by someone else )); plan = Arc::new(RepartitionExec::try_new( diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index 05e3e95..b559f62 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -15,6 +15,7 @@ mod tests { use datafusion::physical_plan::{ execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; + use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule; use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; @@ -49,14 +50,14 @@ mod tests { let mut plan: Arc = Arc::new(ErrorExec::new("something failed")); - for (i, size) in [1, 2, 3].iter().enumerate() { - plan = Arc::new(ArrowFlightReadExec::new( - Partitioning::RoundRobinBatch(*size as usize), - plan.schema(), - i, + for size in [1, 2, 3] { + plan = Arc::new(ArrowFlightReadExec::new_pending( + plan, + Partitioning::RoundRobinBatch(size), )); } - let stream = execute_stream(plan, ctx.task_ctx())?; + let plan = DistributedPhysicalOptimizerRule::default().distribute_plan(plan)?; + let stream = execute_stream(Arc::new(plan), ctx.task_ctx())?; let Err(err) = stream.try_collect::>().await else { panic!("Should have failed") diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index 385c430..d2b247b 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -30,11 +30,10 @@ mod tests { let physical_str = displayable(physical.as_ref()).indent(true).to_string(); let mut physical_distributed = physical.clone(); - for (i, size) in [1, 10, 5].iter().enumerate() { - physical_distributed = Arc::new(ArrowFlightReadExec::new( - Partitioning::RoundRobinBatch(*size as usize), - physical_distributed.schema(), - i, + for size in [1, 10, 5] { + physical_distributed = Arc::new(ArrowFlightReadExec::new_pending( + physical_distributed, + Partitioning::RoundRobinBatch(size), )); } let physical_distributed_str = displayable(physical_distributed.as_ref())