From e3096a6cfe4c89e9426f4a75c37a7e1ddef8336e Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 4 Sep 2025 20:33:31 +0200 Subject: [PATCH 1/6] move plans/ to execution_plans/ --- src/{plan => execution_plans}/arrow_flight_read.rs | 2 +- src/{plan => execution_plans}/codec.rs | 2 +- src/{plan => execution_plans}/mod.rs | 4 ++-- .../isolator.rs => execution_plans/partition_isolator.rs} | 0 src/flight_service/do_get.rs | 2 +- src/lib.rs | 4 ++-- src/physical_optimizer.rs | 2 +- src/stage/display.rs | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) rename src/{plan => execution_plans}/arrow_flight_read.rs (99%) rename src/{plan => execution_plans}/codec.rs (99%) rename src/{plan => execution_plans}/mod.rs (55%) rename src/{plan/isolator.rs => execution_plans/partition_isolator.rs} (100%) diff --git a/src/plan/arrow_flight_read.rs b/src/execution_plans/arrow_flight_read.rs similarity index 99% rename from src/plan/arrow_flight_read.rs rename to src/execution_plans/arrow_flight_read.rs index 700b959..84865dc 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/execution_plans/arrow_flight_read.rs @@ -1,8 +1,8 @@ use crate::channel_manager_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; +use crate::execution_plans::DistributedCodec; use crate::flight_service::{DoGet, StageKey}; -use crate::plan::DistributedCodec; use crate::stage::{proto_from_stage, ExecutionStage}; use crate::ChannelResolver; use arrow_flight::decode::FlightRecordBatchStream; diff --git a/src/plan/codec.rs b/src/execution_plans/codec.rs similarity index 99% rename from src/plan/codec.rs rename to src/execution_plans/codec.rs index 48abe72..aa5f226 100644 --- a/src/plan/codec.rs +++ b/src/execution_plans/codec.rs @@ -1,6 +1,6 @@ use super::PartitionIsolatorExec; use crate::common::ComposedPhysicalExtensionCodec; -use crate::plan::arrow_flight_read::ArrowFlightReadExec; +use crate::execution_plans::arrow_flight_read::ArrowFlightReadExec; use crate::user_codec_ext::get_distributed_user_codec; use datafusion::arrow::datatypes::Schema; use datafusion::execution::FunctionRegistry; diff --git a/src/plan/mod.rs b/src/execution_plans/mod.rs similarity index 55% rename from src/plan/mod.rs rename to src/execution_plans/mod.rs index 7d2a98e..699111e 100644 --- a/src/plan/mod.rs +++ b/src/execution_plans/mod.rs @@ -1,7 +1,7 @@ mod arrow_flight_read; mod codec; -mod isolator; +mod partition_isolator; pub use arrow_flight_read::ArrowFlightReadExec; pub use codec::DistributedCodec; -pub use isolator::{PartitionGroup, PartitionIsolatorExec}; +pub use partition_isolator::{PartitionGroup, PartitionIsolatorExec}; diff --git a/src/plan/isolator.rs b/src/execution_plans/partition_isolator.rs similarity index 100% rename from src/plan/isolator.rs rename to src/execution_plans/partition_isolator.rs diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index bda998b..365c737 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,9 +1,9 @@ use super::service::StageKey; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; +use crate::execution_plans::{DistributedCodec, PartitionGroup}; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; -use crate::plan::{DistributedCodec, PartitionGroup}; use crate::stage::{stage_from_proto, 
ExecutionStage, ExecutionStageProto}; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; diff --git a/src/lib.rs b/src/lib.rs index 1746695..a310174 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,9 +5,9 @@ mod common; mod config_extension_ext; mod distributed_ext; mod errors; +mod execution_plans; mod flight_service; mod physical_optimizer; -mod plan; mod stage; mod task; mod user_codec_ext; @@ -17,11 +17,11 @@ pub mod test_utils; pub use channel_manager_ext::{BoxCloneSyncChannel, ChannelResolver}; pub use distributed_ext::DistributedExt; +pub use execution_plans::{ArrowFlightReadExec, PartitionIsolatorExec}; pub use flight_service::{ ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext, MappedDistributedSessionBuilder, MappedDistributedSessionBuilderExt, }; pub use physical_optimizer::DistributedPhysicalOptimizerRule; -pub use plan::ArrowFlightReadExec; pub use stage::{display_stage_graphviz, ExecutionStage}; diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 7dfeeb3..91e1a21 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::stage::ExecutionStage; -use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec}; +use super::{ArrowFlightReadExec, PartitionIsolatorExec}; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::error::DataFusionError; use datafusion::physical_plan::joins::PartitionMode; diff --git a/src/stage/display.rs b/src/stage/display.rs index fdf0572..bcf7bb2 100644 --- a/src/stage/display.rs +++ b/src/stage/display.rs @@ -1,5 +1,5 @@ use super::ExecutionStage; -use crate::plan::PartitionIsolatorExec; +use crate::PartitionIsolatorExec; use crate::{ task::{format_pg, ExecutionTask}, ArrowFlightReadExec, From f6267376ecb601dc6e310e09e3278030836c4b94 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 4 Sep 2025 20:47:53 +0200 Subject: [PATCH 2/6] Create protobuf/ folder --- src/distributed_ext.rs | 2 +- src/execution_plans/arrow_flight_read.rs | 4 ++-- src/execution_plans/mod.rs | 2 -- src/flight_service/do_get.rs | 16 ++++++++-------- src/lib.rs | 3 +-- .../codec.rs => protobuf/distributed_codec.rs} | 5 ++--- .../execution_stage_proto.rs} | 12 +++++------- src/protobuf/mod.rs | 7 +++++++ .../user_codec.rs} | 0 src/stage/display.rs | 8 +++----- src/stage/execution_stage.rs | 2 +- src/stage/mod.rs | 4 ++-- src/{ => stage}/task.rs | 0 13 files changed, 32 insertions(+), 33 deletions(-) rename src/{execution_plans/codec.rs => protobuf/distributed_codec.rs} (98%) rename src/{stage/proto.rs => protobuf/execution_stage_proto.rs} (96%) create mode 100644 src/protobuf/mod.rs rename src/{user_codec_ext.rs => protobuf/user_codec.rs} (100%) rename src/{ => stage}/task.rs (100%) diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index f7f06e0..503e5dc 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -2,7 +2,7 @@ use crate::channel_manager_ext::set_distributed_channel_resolver; use crate::config_extension_ext::{ set_distributed_option_extension, set_distributed_option_extension_from_headers, }; -use crate::user_codec_ext::set_distributed_user_codec; +use crate::protobuf::set_distributed_user_codec; use crate::ChannelResolver; use datafusion::common::DataFusionError; use datafusion::config::ConfigExtension; diff --git a/src/execution_plans/arrow_flight_read.rs b/src/execution_plans/arrow_flight_read.rs index 84865dc..22dec53 100644 --- 
a/src/execution_plans/arrow_flight_read.rs +++ b/src/execution_plans/arrow_flight_read.rs @@ -1,9 +1,9 @@ use crate::channel_manager_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; -use crate::execution_plans::DistributedCodec; use crate::flight_service::{DoGet, StageKey}; -use crate::stage::{proto_from_stage, ExecutionStage}; +use crate::protobuf::{proto_from_stage, DistributedCodec}; +use crate::stage::ExecutionStage; use crate::ChannelResolver; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 699111e..9d754a2 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -1,7 +1,5 @@ mod arrow_flight_read; -mod codec; mod partition_isolator; pub use arrow_flight_read::ArrowFlightReadExec; -pub use codec::DistributedCodec; pub use partition_isolator::{PartitionGroup, PartitionIsolatorExec}; diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 365c737..a0e0e37 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,10 +1,11 @@ use super::service::StageKey; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; -use crate::execution_plans::{DistributedCodec, PartitionGroup}; +use crate::execution_plans::PartitionGroup; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; -use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto}; +use crate::protobuf::{stage_from_proto, DistributedCodec, ExecutionStageProto}; +use crate::stage::ExecutionStage; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; @@ -169,16 +170,15 @@ impl ArrowFlightEndpoint { #[cfg(test)] mod tests { use super::*; + use crate::flight_service::session_builder::DefaultSessionBuilder; + use crate::stage::ExecutionTask; + use arrow_flight::Ticket; + use prost::{bytes::Bytes, Message}; + use tonic::Request; use uuid::Uuid; #[tokio::test] async fn test_task_data_partition_counting() { - use crate::flight_service::session_builder::DefaultSessionBuilder; - use crate::task::ExecutionTask; - use arrow_flight::Ticket; - use prost::{bytes::Bytes, Message}; - use tonic::Request; - // Create ArrowFlightEndpoint with DefaultSessionBuilder let endpoint = ArrowFlightEndpoint::try_new(DefaultSessionBuilder).expect("Failed to create endpoint"); diff --git a/src/lib.rs b/src/lib.rs index a310174..ff69049 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,9 +9,8 @@ mod execution_plans; mod flight_service; mod physical_optimizer; mod stage; -mod task; -mod user_codec_ext; +mod protobuf; #[cfg(any(feature = "integration", test))] pub mod test_utils; diff --git a/src/execution_plans/codec.rs b/src/protobuf/distributed_codec.rs similarity index 98% rename from src/execution_plans/codec.rs rename to src/protobuf/distributed_codec.rs index aa5f226..319fb26 100644 --- a/src/execution_plans/codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -1,7 +1,6 @@ -use super::PartitionIsolatorExec; +use super::get_distributed_user_codec; use crate::common::ComposedPhysicalExtensionCodec; -use crate::execution_plans::arrow_flight_read::ArrowFlightReadExec; -use crate::user_codec_ext::get_distributed_user_codec; +use 
crate::{ArrowFlightReadExec, PartitionIsolatorExec}; use datafusion::arrow::datatypes::Schema; use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::ExecutionPlan; diff --git a/src/stage/proto.rs b/src/protobuf/execution_stage_proto.rs similarity index 96% rename from src/stage/proto.rs rename to src/protobuf/execution_stage_proto.rs index a4a9f97..ff1e7a4 100644 --- a/src/stage/proto.rs +++ b/src/protobuf/execution_stage_proto.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use crate::stage::ExecutionTask; +use crate::ExecutionStage; use datafusion::{ common::internal_datafusion_err, error::{DataFusionError, Result}, @@ -11,10 +13,6 @@ use datafusion_proto::{ protobuf::PhysicalPlanNode, }; -use crate::task::ExecutionTask; - -use super::ExecutionStage; - #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecutionStageProto { /// Our query id @@ -98,6 +96,9 @@ pub fn stage_from_proto( mod tests { use std::sync::Arc; + use crate::protobuf::execution_stage_proto::ExecutionStageProto; + use crate::protobuf::{proto_from_stage, stage_from_proto}; + use crate::ExecutionStage; use datafusion::{ arrow::{ array::{RecordBatch, StringArray, UInt8Array}, @@ -112,9 +113,6 @@ mod tests { use prost::Message; use uuid::Uuid; - use crate::stage::proto::proto_from_stage; - use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto}; - // create a simple mem table fn create_mem_table() -> Arc { let fields = vec![ diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs new file mode 100644 index 0000000..036f9cf --- /dev/null +++ b/src/protobuf/mod.rs @@ -0,0 +1,7 @@ +mod distributed_codec; +mod execution_stage_proto; +mod user_codec; + +pub(crate) use distributed_codec::DistributedCodec; +pub(crate) use execution_stage_proto::{proto_from_stage, stage_from_proto, ExecutionStageProto}; +pub(crate) use user_codec::{get_distributed_user_codec, set_distributed_user_codec}; diff --git a/src/user_codec_ext.rs b/src/protobuf/user_codec.rs similarity index 100% rename from src/user_codec_ext.rs rename to src/protobuf/user_codec.rs diff --git a/src/stage/display.rs b/src/stage/display.rs index bcf7bb2..a2eb138 100644 --- a/src/stage/display.rs +++ b/src/stage/display.rs @@ -1,9 +1,7 @@ -use super::ExecutionStage; +use super::{ExecutionStage, ExecutionTask}; +use crate::stage::task::format_pg; +use crate::ArrowFlightReadExec; use crate::PartitionIsolatorExec; -use crate::{ - task::{format_pg, ExecutionTask}, - ArrowFlightReadExec, -}; use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; use datafusion::{ error::Result, diff --git a/src/stage/execution_stage.rs b/src/stage/execution_stage.rs index c4ba9b8..15e7635 100644 --- a/src/stage/execution_stage.rs +++ b/src/stage/execution_stage.rs @@ -7,7 +7,7 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use crate::channel_manager_ext::get_distributed_channel_resolver; -use crate::task::ExecutionTask; +use crate::stage::ExecutionTask; use crate::ChannelResolver; use itertools::Itertools; use rand::Rng; diff --git a/src/stage/mod.rs b/src/stage/mod.rs index 3151f96..b23dea0 100644 --- a/src/stage/mod.rs +++ b/src/stage/mod.rs @@ -1,7 +1,7 @@ mod display; mod execution_stage; -mod proto; +mod task; pub use display::display_stage_graphviz; pub use execution_stage::ExecutionStage; -pub use proto::{proto_from_stage, stage_from_proto, ExecutionStageProto}; +pub use task::ExecutionTask; diff --git a/src/task.rs b/src/stage/task.rs similarity index 100% rename from 
src/task.rs rename to src/stage/task.rs From 3985f3d30d2161104ae4dd7ccb47e1e34418846c Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 4 Sep 2025 22:09:42 +0200 Subject: [PATCH 3/6] Move stage along with the other execution plan implementations --- src/execution_plans/arrow_flight_read.rs | 6 +- src/execution_plans/mod.rs | 2 + src/execution_plans/stage.rs | 538 ++++++++++++++++++ src/flight_service/do_get.rs | 91 +-- src/lib.rs | 6 +- src/physical_optimizer.rs | 11 +- src/protobuf/mod.rs | 4 +- ...xecution_stage_proto.rs => stage_proto.rs} | 77 ++- src/stage/display.rs | 247 -------- src/stage/execution_stage.rs | 288 ---------- src/stage/mod.rs | 7 - src/stage/task.rs | 71 --- 12 files changed, 637 insertions(+), 711 deletions(-) create mode 100644 src/execution_plans/stage.rs rename src/protobuf/{execution_stage_proto.rs => stage_proto.rs} (72%) delete mode 100644 src/stage/display.rs delete mode 100644 src/stage/execution_stage.rs delete mode 100644 src/stage/mod.rs delete mode 100644 src/stage/task.rs diff --git a/src/execution_plans/arrow_flight_read.rs b/src/execution_plans/arrow_flight_read.rs index 22dec53..86ab6f8 100644 --- a/src/execution_plans/arrow_flight_read.rs +++ b/src/execution_plans/arrow_flight_read.rs @@ -1,9 +1,9 @@ use crate::channel_manager_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; +use crate::execution_plans::StageExec; use crate::flight_service::{DoGet, StageKey}; use crate::protobuf::{proto_from_stage, DistributedCodec}; -use crate::stage::ExecutionStage; use crate::ChannelResolver; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; @@ -161,7 +161,7 @@ impl ExecutionPlan for ArrowFlightReadExec { let stage = context .session_config() - .get_extension::() + .get_extension::() .ok_or(internal_datafusion_err!( "ArrowFlightReadExec requires an ExecutionStage in the session config" ))?; @@ -218,7 +218,7 @@ impl ExecutionPlan for ArrowFlightReadExec { ); async move { - let url = task.url()?.ok_or(internal_datafusion_err!( + let url = task.url.ok_or(internal_datafusion_err!( "ArrowFlightReadExec: task is unassigned, cannot proceed" ))?; diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 9d754a2..7cd5a5d 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -1,5 +1,7 @@ mod arrow_flight_read; mod partition_isolator; +mod stage; pub use arrow_flight_read::ArrowFlightReadExec; pub use partition_isolator::{PartitionGroup, PartitionIsolatorExec}; +pub use stage::{display_stage_graphviz, ExecutionTask, StageExec}; diff --git a/src/execution_plans/stage.rs b/src/execution_plans/stage.rs new file mode 100644 index 0000000..efb72e2 --- /dev/null +++ b/src/execution_plans/stage.rs @@ -0,0 +1,538 @@ +use crate::channel_manager_ext::get_distributed_channel_resolver; +use crate::{ArrowFlightReadExec, ChannelResolver, PartitionIsolatorExec}; +use datafusion::common::{exec_err, internal_err}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::{ + displayable, DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, +}; +use datafusion::prelude::SessionContext; +use itertools::Itertools; +use rand::Rng; +use std::sync::Arc; +use url::Url; +use uuid::Uuid; + +/// A unit of isolation for a portion of a physical execution plan +/// that can be executed independently and 
across a network boundary. +/// It implements [`ExecutionPlan`] and can be executed to produce a +/// stream of record batches. +/// +/// An ExecutionTask is a finer grained unit of work compared to an ExecutionStage. +/// One ExecutionStage will create one or more ExecutionTasks. +/// +/// When a [`StageExec`] is execute()'d it will execute its plan and return a stream +/// of record batches. +/// +/// If the stage has input stages, then those input stages will be executed on remote resources +/// and will be provided the remainder of the stage tree. +/// +/// For example if our stage tree looks like this: +/// +/// ```text +/// ┌─────────┐ +/// │ stage 1 │ +/// └───┬─────┘ +/// │ +/// ┌──────┴────────┐ +/// ┌────┴────┐ ┌────┴────┐ +/// │ stage 2 │ │ stage 3 │ +/// └────┬────┘ └─────────┘ +/// │ +/// ┌──────┴────────┐ +/// ┌────┴────┐ ┌────┴────┐ +/// │ stage 4 │ │ Stage 5 │ +/// └─────────┘ └─────────┘ +/// +/// ``` +/// +/// Then executing Stage 1 will run its plan locally. Stage 1 has two inputs, Stage 2 and Stage 3. We +/// know these will execute on remote resources. As such the plan for Stage 1 must contain an +/// [`ArrowFlightReadExec`] node that will read the results of Stage 2 and Stage 3 and coalesce the +/// results. +/// +/// When Stage 1's [`ArrowFlightReadExec`] node is executed, it makes an ArrowFlightRequest to the +/// host assigned in the Stage. It provides the following Stage tree serialized in the body of the +/// Arrow Flight Ticket: +/// +/// ```text +/// ┌─────────┐ +/// │ Stage 2 │ +/// └────┬────┘ +/// │ +/// ┌──────┴────────┐ +/// ┌────┴────┐ ┌────┴────┐ +/// │ Stage 4 │ │ Stage 5 │ +/// └─────────┘ └─────────┘ +/// +/// ``` +/// +/// The receiving ArrowFlightEndpoint will then execute Stage 2 and will repeat this process. +/// +/// When Stage 4 is executed, it has no input tasks, so it is assumed that the plan included in that +/// Stage can complete on its own; it's likely holding a leaf node in the overall physical plan and +/// producing data from a [`DataSourceExec`]. +#[derive(Debug, Clone)] +pub struct StageExec { + /// Our query_id + pub query_id: Uuid, + /// Our stage number + pub num: usize, + /// Our stage name + pub name: String, + /// The physical execution plan that this stage will execute. + pub plan: Arc, + /// The input stages to this stage + pub inputs: Vec>, + /// Our tasks which tell us how finely grained to execute the partitions in + /// the plan + pub tasks: Vec, + /// tree depth of our location in the stage tree, used for display only + pub depth: usize, +} + +#[derive(Debug, Clone)] +pub struct ExecutionTask { + /// The url of the worker that will execute this task. A None value is interpreted as + /// unassigned. + pub url: Option, + /// The partitions that we can execute from this plan + pub partition_group: Vec, +} + +impl StageExec { + /// Creates a new `ExecutionStage` with the given plan and inputs. One task will be created + /// responsible for partitions in the plan. + pub fn new( + query_id: Uuid, + num: usize, + plan: Arc, + inputs: Vec>, + ) -> Self { + let name = format!("Stage {:<3}", num); + let partition_group = (0..plan.properties().partitioning.partition_count()).collect(); + StageExec { + query_id, + num, + name, + plan, + inputs: inputs + .into_iter() + .map(|s| s as Arc) + .collect(), + tasks: vec![ExecutionTask { + partition_group, + url: None, + }], + depth: 0, + } + } + + /// Recalculate the tasks for this stage based on the number of partitions in the plan + /// and the maximum number of partitions per task.
+ /// + /// This will unset any worker assignments + pub fn with_maximum_partitions_per_task(mut self, max_partitions_per_task: usize) -> Self { + let partitions = self.plan.properties().partitioning.partition_count(); + + self.tasks = (0..partitions) + .chunks(max_partitions_per_task) + .into_iter() + .map(|partition_group| ExecutionTask { + partition_group: partition_group.collect(), + url: None, + }) + .collect(); + self + } + + /// Returns the name of this stage + pub fn name(&self) -> String { + format!("Stage {:<3}", self.num) + } + + /// Returns an iterator over the child stages of this stage cast as &ExecutionStage + /// which can be useful + pub fn child_stages_iter(&self) -> impl Iterator { + self.inputs + .iter() + .filter_map(|s| s.as_any().downcast_ref::()) + } + + /// Returns the name of this stage including child stage numbers if any. + pub fn name_with_children(&self) -> String { + let child_str = if self.inputs.is_empty() { + "".to_string() + } else { + format!( + " Child Stages:[{}] ", + self.child_stages_iter() + .map(|s| format!("{}", s.num)) + .collect::>() + .join(", ") + ) + }; + format!("Stage {:<3}{}", self.num, child_str) + } + + pub fn try_assign(self, channel_resolver: &impl ChannelResolver) -> Result { + let urls: Vec = channel_resolver.get_urls()?; + if urls.is_empty() { + return internal_err!("No URLs found in ChannelManager"); + } + + Ok(self) + } + + fn try_assign_urls(&self, urls: &[Url]) -> Result { + let assigned_children = self + .child_stages_iter() + .map(|child| { + child + .clone() // TODO: avoid cloning if possible + .try_assign_urls(urls) + .map(|c| Arc::new(c) as Arc) + }) + .collect::>>()?; + + // pick a random starting position + let mut rng = rand::thread_rng(); + let start_idx = rng.gen_range(0..urls.len()); + + let assigned_tasks = self + .tasks + .iter() + .enumerate() + .map(|(i, task)| ExecutionTask { + partition_group: task.partition_group.clone(), + url: Some(urls[(start_idx + i) % urls.len()].clone()), + }) + .collect::>(); + + let assigned_stage = StageExec { + query_id: self.query_id, + num: self.num, + name: self.name.clone(), + plan: self.plan.clone(), + inputs: assigned_children, + tasks: assigned_tasks, + depth: self.depth, + }; + + Ok(assigned_stage) + } +} + +impl ExecutionPlan for StageExec { + fn name(&self) -> &str { + &self.name + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn children(&self) -> Vec<&Arc> { + self.inputs.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(StageExec { + query_id: self.query_id, + num: self.num, + name: self.name.clone(), + plan: self.plan.clone(), + inputs: children, + tasks: self.tasks.clone(), + depth: self.depth, + })) + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + self.plan.properties() + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let stage = self + .as_any() + .downcast_ref::() + .expect("Unwrapping myself should always work"); + + let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config()) + else { + return exec_err!("ChannelManager not found in session config"); + }; + + let urls = channel_resolver.get_urls()?; + + let assigned_stage = stage + .try_assign_urls(&urls) + .map(Arc::new) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + // insert the stage into the context so that ExecutionPlan nodes + // that care about the stage can access it + let config = context + .session_config() + .clone() + 
.with_extension(assigned_stage.clone()); + + let new_ctx = + SessionContext::new_with_config_rt(config, context.runtime_env().clone()).task_ctx(); + + assigned_stage.plan.execute(partition, new_ctx) + } +} + +/// Be able to display a nice tree for stages. +/// +/// The challenge to doing this at the moment is that `TreeRenderVistor` +/// in [`datafusion::physical_plan::display`] is not public, and that it also +/// is specific to a `ExecutionPlan` trait object, which we don't have. +/// +/// TODO: try to upstream a change to make rendering of Trees (logical, physical, stages) against +/// a generic trait rather than a specific trait object. This would allow us to +/// use the same rendering code for all trees, including stages. +/// +/// In the meantime, we can make a dummy ExecutionPlan that will let us render +/// the Stage tree. +use std::fmt::Write; + +// Unicode box-drawing characters for creating borders and connections. +const LTCORNER: &str = "┌"; // Left top corner +const LDCORNER: &str = "└"; // Left bottom corner +const VERTICAL: &str = "│"; // Vertical line +const HORIZONTAL: &str = "─"; // Horizontal line + +impl StageExec { + fn format(&self, plan: &dyn ExecutionPlan, indent: usize, f: &mut String) -> std::fmt::Result { + let mut node_str = displayable(plan).one_line().to_string(); + node_str.pop(); + write!(f, "{} {node_str}", " ".repeat(indent))?; + + if let Some(ArrowFlightReadExec::Ready(ready)) = + plan.as_any().downcast_ref::() + { + let Some(input_stage) = &self.child_stages_iter().find(|v| v.num == ready.stage_num) + else { + writeln!(f, "Wrong partition number {}", ready.stage_num)?; + return Ok(()); + }; + let tasks = input_stage.tasks.len(); + let partitions = plan.output_partitioning().partition_count(); + let stage = ready.stage_num; + write!( + f, + " input_stage={stage}, input_partitions={partitions}, input_tasks={tasks}", + )?; + } + + if plan.as_any().is::() { + write!(f, " {}", format_tasks_for_partition_isolator(&self.tasks))?; + } + writeln!(f)?; + + for child in plan.children() { + self.format(child.as_ref(), indent + 2, f)?; + } + Ok(()) + } +} + +impl DisplayAs for StageExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + #[allow(clippy::format_in_format_args)] + match t { + DisplayFormatType::Default => { + write!(f, "{}", self.name) + } + DisplayFormatType::Verbose => { + writeln!( + f, + "{}{}{}{}", + LTCORNER, + HORIZONTAL.repeat(5), + format!(" {} ", self.name), + format_tasks_for_stage(&self.tasks), + )?; + + let mut plan_str = String::new(); + self.format(self.plan.as_ref(), 0, &mut plan_str)?; + let plan_str = plan_str + .split('\n') + .filter(|v| !v.is_empty()) + .collect::>() + .join(&format!("\n{}{}", " ".repeat(self.depth), VERTICAL)); + writeln!(f, "{}{}{}", " ".repeat(self.depth), VERTICAL, plan_str)?; + write!( + f, + "{}{}{}", + " ".repeat(self.depth), + LDCORNER, + HORIZONTAL.repeat(50) + )?; + + Ok(()) + } + DisplayFormatType::TreeRender => write!(f, "{}", format_tasks_for_stage(&self.tasks),), + } + } +} + +fn format_tasks_for_stage(tasks: &[ExecutionTask]) -> String { + let mut result = "Tasks: ".to_string(); + for (i, t) in tasks.iter().enumerate() { + result += &format!("t{i}:["); + result += &t.partition_group.iter().map(|v| format!("p{v}")).join(","); + result += "] " + } + result +} + +fn format_tasks_for_partition_isolator(tasks: &[ExecutionTask]) -> String { + let mut result = "Tasks: ".to_string(); + let mut partitions = vec![]; + for t in tasks.iter() { + 
partitions.extend(vec!["__".to_string(); t.partition_group.len()]) + } + for (i, t) in tasks.iter().enumerate() { + let mut partitions = partitions.clone(); + for (i, p) in t.partition_group.iter().enumerate() { + partitions[*p] = format!("p{i}") + } + result += &format!("t{i}:[{}] ", partitions.join(",")); + } + result +} + +pub fn display_stage_graphviz(stage: &StageExec) -> Result { + let mut f = String::new(); + + let num_colors = 5; // this should aggree with the colorscheme chosen from + // https://graphviz.org/doc/info/colors.html + let colorscheme = "spectral5"; + + writeln!(f, "digraph G {{")?; + writeln!(f, " node[shape=rect];")?; + writeln!(f, " rankdir=BT;")?; + writeln!(f, " ranksep=2;")?; + writeln!(f, " edge[colorscheme={},penwidth=2.0];", colorscheme)?; + + // we'll keep a stack of stage ref, parrent stage ref + let mut stack: Vec<(&StageExec, Option<&StageExec>)> = vec![(stage, None)]; + + while let Some((stage, parent)) = stack.pop() { + writeln!(f, " subgraph cluster_{} {{", stage.num)?; + writeln!(f, " node[shape=record];")?; + writeln!(f, " label=\"{}\";", stage.name())?; + writeln!(f, " labeljust=r;")?; + writeln!(f, " labelloc=b;")?; // this will put the label at the top as our + // rankdir=BT + + stage.tasks.iter().try_for_each(|task| { + let lab = task + .partition_group + .iter() + .map(|p| format!("{}", p, p)) + .collect::>() + .join("|"); + writeln!( + f, + " \"{}_{}\"[label = \"{}\"]", + stage.num, + format_partition_group(&task.partition_group), + lab, + )?; + + if let Some(our_parent) = parent { + our_parent.tasks.iter().try_for_each(|ptask| { + task.partition_group.iter().try_for_each(|partition| { + ptask.partition_group.iter().try_for_each(|ppartition| { + writeln!( + f, + " \"{}_{}\":p{}:n -> \"{}_{}\":p{}:s[color={}]", + stage.num, + format_partition_group(&task.partition_group), + partition, + our_parent.num, + format_partition_group(&ptask.partition_group), + ppartition, + (partition) % num_colors + 1 + ) + }) + }) + })?; + } + + Ok::<(), std::fmt::Error>(()) + })?; + + // now we try to force the left right nature of tasks to be honored + writeln!(f, " {{")?; + writeln!(f, " rank = same;")?; + stage.tasks.iter().try_for_each(|task| { + writeln!( + f, + " \"{}_{}\"", + stage.num, + format_partition_group(&task.partition_group) + )?; + + Ok::<(), std::fmt::Error>(()) + })?; + writeln!(f, " }}")?; + // combined with rank = same, the invisible edges will force the tasks to be + // laid out in a single row within the stage + for i in 0..stage.tasks.len() - 1 { + writeln!( + f, + " \"{}_{}\":w -> \"{}_{}\":e[style=invis]", + stage.num, + format_partition_group(&stage.tasks[i].partition_group), + stage.num, + format_partition_group(&stage.tasks[i + 1].partition_group), + )?; + } + + // add a node for the plan, its way too big! Alternatives to add it? 
+ /*writeln!( + f, + " \"{}_plan\"[label = \"{}\", shape=box];", + stage.num, + displayable(stage.plan.as_ref()).indent(false) + )?; + */ + + writeln!(f, " }}")?; + + for child in stage.child_stages_iter() { + stack.push((child, Some(stage))); + } + } + + writeln!(f, "}}")?; + Ok(f) +} + +pub fn format_partition_group(partition_group: &[usize]) -> String { + if partition_group.len() > 2 { + format!( + "{}..{}", + partition_group[0], + partition_group[partition_group.len() - 1] + ) + } else { + partition_group + .iter() + .map(|pg| format!("{pg}")) + .collect::>() + .join(",") + } +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index a0e0e37..c316525 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,11 +1,10 @@ use super::service::StageKey; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; -use crate::execution_plans::PartitionGroup; +use crate::execution_plans::{PartitionGroup, StageExec}; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; -use crate::protobuf::{stage_from_proto, DistributedCodec, ExecutionStageProto}; -use crate::stage::ExecutionStage; +use crate::protobuf::{stage_from_proto, DistributedCodec, StageExecProto}; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; @@ -23,7 +22,7 @@ use tonic::{Request, Response, Status}; pub struct DoGet { /// The ExecutionStage that we are going to execute #[prost(message, optional, tag = "1")] - pub stage_proto: Option, + pub stage_proto: Option, /// The index to the task within the stage that we want to execute #[prost(uint64, tag = "2")] pub task_number: u64, @@ -43,7 +42,7 @@ pub struct DoGet { /// by concurrent requests for the same task which execute separate partitions. pub struct TaskData { pub(super) state: SessionState, - pub(super) stage: Arc, + pub(super) stage: Arc, ///num_partitions_remaining is initialized to the total number of partitions in the task (not /// only tasks in the partition group). This is decremented for each request to the endpoint /// for this task. Once this count is zero, the task is likely complete. 
The task may not be @@ -80,8 +79,7 @@ impl ArrowFlightEndpoint { stage.name() )))?; - let partition_group = - PartitionGroup(task.partition_group.iter().map(|p| *p as usize).collect()); + let partition_group = PartitionGroup(task.partition_group.clone()); state.config_mut().set_extension(Arc::new(partition_group)); let inner_plan = stage.plan.clone(); @@ -171,8 +169,15 @@ impl ArrowFlightEndpoint { mod tests { use super::*; use crate::flight_service::session_builder::DefaultSessionBuilder; - use crate::stage::ExecutionTask; + use crate::protobuf::proto_from_stage; + use crate::ExecutionTask; + use arrow::datatypes::{Schema, SchemaRef}; use arrow_flight::Ticket; + use datafusion::physical_expr::Partitioning; + use datafusion::physical_plan::empty::EmptyExec; + use datafusion::physical_plan::repartition::RepartitionExec; + use datafusion::physical_plan::ExecutionPlan; + use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use prost::{bytes::Bytes, Message}; use tonic::Request; use uuid::Uuid; @@ -187,47 +192,45 @@ mod tests { let num_tasks = 3; let num_partitions_per_task = 3; let stage_id = 1; - let query_id_uuid = Uuid::new_v4(); - let query_id = query_id_uuid.as_bytes().to_vec(); + let query_id = Uuid::new_v4(); // Set up protos. let mut tasks = Vec::new(); for i in 0..num_tasks { tasks.push(ExecutionTask { - url_str: None, + url: None, partition_group: vec![i], // Set a random partition in the partition group. }); } - let stage_proto = ExecutionStageProto { - query_id: query_id.clone(), + let stage = StageExec { + query_id, num: 1, name: format!("test_stage_{}", 1), - plan: Some(Box::new(create_mock_physical_plan_proto( - num_partitions_per_task, - ))), + plan: create_mock_physical_plan(num_partitions_per_task), inputs: vec![], tasks, + depth: 0, }; let task_keys = [ StageKey { - query_id: query_id_uuid.to_string(), + query_id: query_id.to_string(), stage_id, task_number: 0, }, StageKey { - query_id: query_id_uuid.to_string(), + query_id: query_id.to_string(), stage_id, task_number: 1, }, StageKey { - query_id: query_id_uuid.to_string(), + query_id: query_id.to_string(), stage_id, task_number: 2, }, ]; - + let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap(); let stage_proto_for_closure = stage_proto.clone(); let endpoint_ref = &endpoint; let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| { @@ -251,20 +254,15 @@ mod tests { }; // For each task, call do_get() for each partition except the last. - for task_number in 0..num_tasks { + for (task_number, task_key) in task_keys.iter().enumerate() { for partition in 0..num_partitions_per_task - 1 { - let result = do_get( - partition as u64, - task_number, - task_keys[task_number as usize].clone(), - ) - .await; + let result = do_get(partition as u64, task_number as u64, task_key.clone()).await; assert!(result.is_ok()); } } // Check that the endpoint has not evicted any task states. - assert_eq!(endpoint.stages.len(), num_tasks as usize); + assert_eq!(endpoint.stages.len(), num_tasks); // Run the last partition of task 0. Any partition number works. Verify that the task state // is evicted because all partitions have been processed. 
@@ -289,38 +287,9 @@ mod tests { assert_eq!(stored_stage_keys.len(), 0); } - // Helper to create a mock physical plan proto - fn create_mock_physical_plan_proto( - partitions: usize, - ) -> datafusion_proto::protobuf::PhysicalPlanNode { - use datafusion_proto::protobuf::partitioning::PartitionMethod; - use datafusion_proto::protobuf::{ - Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema, - }; - - // Create a repartition node that will have the desired partition count - PhysicalPlanNode { - physical_plan_type: Some( - datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition( - Box::new(RepartitionExecNode { - input: Some(Box::new(PhysicalPlanNode { - physical_plan_type: Some( - datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty( - datafusion_proto::protobuf::EmptyExecNode { - schema: Some(Schema { - columns: vec![], - metadata: std::collections::HashMap::new(), - }) - } - ) - ), - })), - partitioning: Some(Partitioning { - partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)), - }), - }) - ) - ), - } + // Helper to create a mock physical plan + fn create_mock_physical_plan(partitions: usize) -> Arc { + let node = Arc::new(EmptyExec::new(SchemaRef::new(Schema::empty()))); + Arc::new(RepartitionExec::try_new(node, Partitioning::RoundRobinBatch(partitions)).unwrap()) } } diff --git a/src/lib.rs b/src/lib.rs index ff69049..f07103b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,6 @@ mod errors; mod execution_plans; mod flight_service; mod physical_optimizer; -mod stage; mod protobuf; #[cfg(any(feature = "integration", test))] @@ -16,11 +15,12 @@ pub mod test_utils; pub use channel_manager_ext::{BoxCloneSyncChannel, ChannelResolver}; pub use distributed_ext::DistributedExt; -pub use execution_plans::{ArrowFlightReadExec, PartitionIsolatorExec}; +pub use execution_plans::{ + display_stage_graphviz, ArrowFlightReadExec, ExecutionTask, PartitionIsolatorExec, StageExec, +}; pub use flight_service::{ ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext, MappedDistributedSessionBuilder, MappedDistributedSessionBuilderExt, }; pub use physical_optimizer::DistributedPhysicalOptimizerRule; -pub use stage::{display_stage_graphviz, ExecutionStage}; diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 91e1a21..70c0a0a 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -1,7 +1,6 @@ use std::sync::Arc; -use super::stage::ExecutionStage; -use super::{ArrowFlightReadExec, PartitionIsolatorExec}; +use super::{ArrowFlightReadExec, PartitionIsolatorExec, StageExec}; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::error::DataFusionError; use datafusion::physical_plan::joins::PartitionMode; @@ -55,7 +54,7 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { _config: &ConfigOptions, ) -> Result> { // We can only optimize plans that are not already distributed - if plan.as_any().is::() { + if plan.as_any().is::() { return Ok(plan); } @@ -106,7 +105,7 @@ impl DistributedPhysicalOptimizerRule { pub fn distribute_plan( &self, plan: Arc, - ) -> Result { + ) -> Result { let query_id = Uuid::new_v4(); self._distribute_plan_inner(query_id, plan, &mut 1, 0) } @@ -117,7 +116,7 @@ impl DistributedPhysicalOptimizerRule { plan: Arc, num: &mut usize, depth: usize, - ) -> Result { + ) -> Result { let mut inputs = vec![]; let distributed = plan.clone().transform_down(|plan| { @@ -134,7 +133,7 @@ impl DistributedPhysicalOptimizerRule 
{ })?; let inputs = inputs.into_iter().map(Arc::new).collect(); - let mut stage = ExecutionStage::new(query_id, *num, distributed.data, inputs); + let mut stage = StageExec::new(query_id, *num, distributed.data, inputs); *num += 1; stage = match (self.partitions_per_task, can_be_divided(&plan)?) { diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs index 036f9cf..0891846 100644 --- a/src/protobuf/mod.rs +++ b/src/protobuf/mod.rs @@ -1,7 +1,7 @@ mod distributed_codec; -mod execution_stage_proto; +mod stage_proto; mod user_codec; pub(crate) use distributed_codec::DistributedCodec; -pub(crate) use execution_stage_proto::{proto_from_stage, stage_from_proto, ExecutionStageProto}; +pub(crate) use stage_proto::{proto_from_stage, stage_from_proto, StageExecProto}; pub(crate) use user_codec::{get_distributed_user_codec, set_distributed_user_codec}; diff --git a/src/protobuf/execution_stage_proto.rs b/src/protobuf/stage_proto.rs similarity index 72% rename from src/protobuf/execution_stage_proto.rs rename to src/protobuf/stage_proto.rs index ff1e7a4..332e6f7 100644 --- a/src/protobuf/execution_stage_proto.rs +++ b/src/protobuf/stage_proto.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use crate::stage::ExecutionTask; -use crate::ExecutionStage; +use crate::execution_plans::{ExecutionTask, StageExec}; use datafusion::{ common::internal_datafusion_err, error::{DataFusionError, Result}, @@ -12,56 +9,76 @@ use datafusion_proto::{ physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}, protobuf::PhysicalPlanNode, }; +use std::sync::Arc; +use url::Url; #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ExecutionStageProto { +pub struct StageExecProto { /// Our query id #[prost(bytes, tag = "1")] - pub query_id: Vec, + query_id: Vec, /// Our stage number #[prost(uint64, tag = "2")] - pub num: u64, + num: u64, /// Our stage name #[prost(string, tag = "3")] - pub name: String, + name: String, /// The physical execution plan that this stage will execute. #[prost(message, optional, boxed, tag = "4")] - pub plan: Option>, + plan: Option>, /// The input stages to this stage #[prost(repeated, message, tag = "5")] - pub inputs: Vec, + inputs: Vec, /// Our tasks which tell us how finely grained to execute the partitions in /// the plan #[prost(message, repeated, tag = "6")] - pub tasks: Vec, + tasks: Vec, +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ExecutionTaskProto { + /// The url of the worker that will execute this task. A None value is interpreted as + /// unassigned. 
+ #[prost(string, optional, tag = "1")] + url_str: Option, + /// The partitions that we can execute from this plan + #[prost(uint64, repeated, tag = "2")] + partition_group: Vec, } pub fn proto_from_stage( - stage: &ExecutionStage, + stage: &StageExec, codec: &dyn PhysicalExtensionCodec, -) -> Result { +) -> Result { let proto_plan = PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec)?; let inputs = stage .child_stages_iter() .map(|s| proto_from_stage(s, codec)) .collect::>>()?; - Ok(ExecutionStageProto { + Ok(StageExecProto { query_id: stage.query_id.as_bytes().to_vec(), num: stage.num as u64, name: stage.name(), plan: Some(Box::new(proto_plan)), inputs, - tasks: stage.tasks.clone(), + tasks: stage + .tasks + .iter() + .map(|task| ExecutionTaskProto { + url_str: task.url.as_ref().map(|v| v.to_string()), + partition_group: task.partition_group.iter().map(|v| *v as u64).collect(), + }) + .collect(), }) } pub fn stage_from_proto( - msg: ExecutionStageProto, + msg: StageExecProto, registry: &dyn FunctionRegistry, runtime: &RuntimeEnv, codec: &dyn PhysicalExtensionCodec, -) -> Result { +) -> Result { let plan_node = msg.plan.ok_or(internal_datafusion_err!( "ExecutionStageMsg is missing the plan" ))?; @@ -77,7 +94,7 @@ pub fn stage_from_proto( }) .collect::>>()?; - Ok(ExecutionStage { + Ok(StageExec { query_id: msg .query_id .try_into() @@ -86,7 +103,21 @@ pub fn stage_from_proto( name: msg.name, plan, inputs, - tasks: msg.tasks, + tasks: msg + .tasks + .into_iter() + .map(|task| { + Ok(ExecutionTask { + url: task + .url_str + .map(|u| { + Url::parse(&u).map_err(|_| internal_datafusion_err!("Invalid URL: {u}")) + }) + .transpose()?, + partition_group: task.partition_group.iter().map(|v| *v as usize).collect(), + }) + }) + .collect::>>()?, depth: 0, }) } @@ -96,9 +127,9 @@ pub fn stage_from_proto( mod tests { use std::sync::Arc; - use crate::protobuf::execution_stage_proto::ExecutionStageProto; + use crate::protobuf::stage_proto::StageExecProto; use crate::protobuf::{proto_from_stage, stage_from_proto}; - use crate::ExecutionStage; + use crate::StageExec; use datafusion::{ arrow::{ array::{RecordBatch, StringArray, UInt8Array}, @@ -147,7 +178,7 @@ mod tests { .await?; // Wrap it in an ExecutionStage - let stage = ExecutionStage { + let stage = StageExec { query_id: Uuid::new_v4(), num: 1, name: "TestStage".to_string(), @@ -167,7 +198,7 @@ mod tests { .map_err(|e| internal_datafusion_err!("couldn't encode {e:#?}"))?; // Deserialize from bytes - let decoded_msg = ExecutionStageProto::decode(&buf[..]) + let decoded_msg = StageExecProto::decode(&buf[..]) .map_err(|e| internal_datafusion_err!("couldn't decode {e:#?}"))?; // Convert back to ExecutionStage diff --git a/src/stage/display.rs b/src/stage/display.rs deleted file mode 100644 index a2eb138..0000000 --- a/src/stage/display.rs +++ /dev/null @@ -1,247 +0,0 @@ -use super::{ExecutionStage, ExecutionTask}; -use crate::stage::task::format_pg; -use crate::ArrowFlightReadExec; -use crate::PartitionIsolatorExec; -use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; -use datafusion::{ - error::Result, - physical_plan::{DisplayAs, DisplayFormatType}, -}; -use itertools::Itertools; -/// Be able to display a nice tree for stages. -/// -/// The challenge to doing this at the moment is that `TreeRenderVistor` -/// in [`datafusion::physical_plan::display`] is not public, and that it also -/// is specific to a `ExecutionPlan` trait object, which we don't have. 
-/// -/// TODO: try to upstream a change to make rendering of Trees (logical, physical, stages) against -/// a generic trait rather than a specific trait object. This would allow us to -/// use the same rendering code for all trees, including stages. -/// -/// In the meantime, we can make a dummy ExecutionPlan that will let us render -/// the Stage tree. -use std::fmt::Write; - -// Unicode box-drawing characters for creating borders and connections. -const LTCORNER: &str = "┌"; // Left top corner -const LDCORNER: &str = "└"; // Left bottom corner -const VERTICAL: &str = "│"; // Vertical line -const HORIZONTAL: &str = "─"; // Horizontal line - -impl ExecutionStage { - fn format(&self, plan: &dyn ExecutionPlan, indent: usize, f: &mut String) -> std::fmt::Result { - let mut node_str = displayable(plan).one_line().to_string(); - node_str.pop(); - write!(f, "{} {node_str}", " ".repeat(indent))?; - - if let Some(ArrowFlightReadExec::Ready(ready)) = - plan.as_any().downcast_ref::() - { - let Some(input_stage) = &self.child_stages_iter().find(|v| v.num == ready.stage_num) - else { - writeln!(f, "Wrong partition number {}", ready.stage_num)?; - return Ok(()); - }; - let tasks = input_stage.tasks.len(); - let partitions = plan.output_partitioning().partition_count(); - let stage = ready.stage_num; - write!( - f, - " input_stage={stage}, input_partitions={partitions}, input_tasks={tasks}", - )?; - } - - if plan.as_any().is::() { - write!(f, " {}", format_tasks_for_partition_isolator(&self.tasks))?; - } - writeln!(f)?; - - for child in plan.children() { - self.format(child.as_ref(), indent + 2, f)?; - } - Ok(()) - } -} - -impl DisplayAs for ExecutionStage { - fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { - #[allow(clippy::format_in_format_args)] - match t { - DisplayFormatType::Default => { - write!(f, "{}", self.name) - } - DisplayFormatType::Verbose => { - writeln!( - f, - "{}{}{}{}", - LTCORNER, - HORIZONTAL.repeat(5), - format!(" {} ", self.name), - format_tasks_for_stage(&self.tasks), - )?; - - let mut plan_str = String::new(); - self.format(self.plan.as_ref(), 0, &mut plan_str)?; - let plan_str = plan_str - .split('\n') - .filter(|v| !v.is_empty()) - .collect::>() - .join(&format!("\n{}{}", " ".repeat(self.depth), VERTICAL)); - writeln!(f, "{}{}{}", " ".repeat(self.depth), VERTICAL, plan_str)?; - write!( - f, - "{}{}{}", - " ".repeat(self.depth), - LDCORNER, - HORIZONTAL.repeat(50) - )?; - - Ok(()) - } - DisplayFormatType::TreeRender => write!( - f, - "{}", - self.tasks - .iter() - .map(|task| format!("{task}")) - .collect::>() - .join("\n") - ), - } - } -} - -pub fn display_stage_graphviz(stage: &ExecutionStage) -> Result { - let mut f = String::new(); - - let num_colors = 5; // this should aggree with the colorscheme chosen from - // https://graphviz.org/doc/info/colors.html - let colorscheme = "spectral5"; - - writeln!(f, "digraph G {{")?; - writeln!(f, " node[shape=rect];")?; - writeln!(f, " rankdir=BT;")?; - writeln!(f, " ranksep=2;")?; - writeln!(f, " edge[colorscheme={},penwidth=2.0];", colorscheme)?; - - // we'll keep a stack of stage ref, parrent stage ref - let mut stack: Vec<(&ExecutionStage, Option<&ExecutionStage>)> = vec![(stage, None)]; - - while let Some((stage, parent)) = stack.pop() { - writeln!(f, " subgraph cluster_{} {{", stage.num)?; - writeln!(f, " node[shape=record];")?; - writeln!(f, " label=\"{}\";", stage.name())?; - writeln!(f, " labeljust=r;")?; - writeln!(f, " labelloc=b;")?; // this will put the label at the top as our - 
// rankdir=BT - - stage.tasks.iter().try_for_each(|task| { - let lab = task - .partition_group - .iter() - .map(|p| format!("{}", p, p)) - .collect::>() - .join("|"); - writeln!( - f, - " \"{}_{}\"[label = \"{}\"]", - stage.num, - format_pg(&task.partition_group), - lab, - )?; - - if let Some(our_parent) = parent { - our_parent.tasks.iter().try_for_each(|ptask| { - task.partition_group.iter().try_for_each(|partition| { - ptask.partition_group.iter().try_for_each(|ppartition| { - writeln!( - f, - " \"{}_{}\":p{}:n -> \"{}_{}\":p{}:s[color={}]", - stage.num, - format_pg(&task.partition_group), - partition, - our_parent.num, - format_pg(&ptask.partition_group), - ppartition, - (partition) % num_colors + 1 - ) - }) - }) - })?; - } - - Ok::<(), std::fmt::Error>(()) - })?; - - // now we try to force the left right nature of tasks to be honored - writeln!(f, " {{")?; - writeln!(f, " rank = same;")?; - stage.tasks.iter().try_for_each(|task| { - writeln!( - f, - " \"{}_{}\"", - stage.num, - format_pg(&task.partition_group) - )?; - - Ok::<(), std::fmt::Error>(()) - })?; - writeln!(f, " }}")?; - // combined with rank = same, the invisible edges will force the tasks to be - // laid out in a single row within the stage - for i in 0..stage.tasks.len() - 1 { - writeln!( - f, - " \"{}_{}\":w -> \"{}_{}\":e[style=invis]", - stage.num, - format_pg(&stage.tasks[i].partition_group), - stage.num, - format_pg(&stage.tasks[i + 1].partition_group), - )?; - } - - // add a node for the plan, its way too big! Alternatives to add it? - /*writeln!( - f, - " \"{}_plan\"[label = \"{}\", shape=box];", - stage.num, - displayable(stage.plan.as_ref()).indent(false) - )?; - */ - - writeln!(f, " }}")?; - - for child in stage.child_stages_iter() { - stack.push((child, Some(stage))); - } - } - - writeln!(f, "}}")?; - Ok(f) -} - -fn format_tasks_for_stage(tasks: &[ExecutionTask]) -> String { - let mut result = "Tasks: ".to_string(); - for (i, t) in tasks.iter().enumerate() { - result += &format!("t{i}:["); - result += &t.partition_group.iter().map(|v| format!("p{v}")).join(","); - result += "] " - } - result -} - -fn format_tasks_for_partition_isolator(tasks: &[ExecutionTask]) -> String { - let mut result = "Tasks: ".to_string(); - let mut partitions = vec![]; - for t in tasks.iter() { - partitions.extend(vec!["__".to_string(); t.partition_group.len()]) - } - for (i, t) in tasks.iter().enumerate() { - let mut partitions = partitions.clone(); - for (i, p) in t.partition_group.iter().enumerate() { - partitions[*p as usize] = format!("p{i}") - } - result += &format!("t{i}:[{}] ", partitions.join(",")); - } - result -} diff --git a/src/stage/execution_stage.rs b/src/stage/execution_stage.rs deleted file mode 100644 index 15e7635..0000000 --- a/src/stage/execution_stage.rs +++ /dev/null @@ -1,288 +0,0 @@ -use std::sync::Arc; - -use datafusion::common::{exec_err, internal_err}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::TaskContext; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; - -use crate::channel_manager_ext::get_distributed_channel_resolver; -use crate::stage::ExecutionTask; -use crate::ChannelResolver; -use itertools::Itertools; -use rand::Rng; -use url::Url; -use uuid::Uuid; - -/// A unit of isolation for a portion of a physical execution plan -/// that can be executed independently and across a network boundary. -/// It implements [`ExecutionPlan`] and can be executed to produce a -/// stream of record batches. 
-/// -/// An ExecutionTask is a finer grained unit of work compared to an ExecutionStage. -/// One ExecutionStage will create one or more ExecutionTasks -/// -/// When an [`ExecutionStage`] is execute()'d if will execute its plan and return a stream -/// of record batches. -/// -/// If the stage has input stages, then it those input stages will be executed on remote resources -/// and will be provided the remainder of the stage tree. -/// -/// For example if our stage tree looks like this: -/// -/// ```text -/// ┌─────────┐ -/// │ stage 1 │ -/// └───┬─────┘ -/// │ -/// ┌──────┴────────┐ -/// ┌────┴────┐ ┌────┴────┐ -/// │ stage 2 │ │ stage 3 │ -/// └────┬────┘ └─────────┘ -/// │ -/// ┌──────┴────────┐ -/// ┌────┴────┐ ┌────┴────┐ -/// │ stage 4 │ │ Stage 5 │ -/// └─────────┘ └─────────┘ -/// -/// ``` -/// -/// Then executing Stage 1 will run its plan locally. Stage 1 has two inputs, Stage 2 and Stage 3. We -/// know these will execute on remote resources. As such the plan for Stage 1 must contain an -/// [`ArrowFlightReadExec`] node that will read the results of Stage 2 and Stage 3 and coalese the -/// results. -/// -/// When Stage 1's [`ArrowFlightReadExec`] node is executed, it makes an ArrowFlightRequest to the -/// host assigned in the Stage. It provides the following Stage tree serialilzed in the body of the -/// Arrow Flight Ticket: -/// -/// ```text -/// ┌─────────┐ -/// │ Stage 2 │ -/// └────┬────┘ -/// │ -/// ┌──────┴────────┐ -/// ┌────┴────┐ ┌────┴────┐ -/// │ Stage 4 │ │ Stage 5 │ -/// └─────────┘ └─────────┘ -/// -/// ``` -/// -/// The receiving ArrowFlightEndpoint will then execute Stage 2 and will repeat this process. -/// -/// When Stage 4 is executed, it has no input tasks, so it is assumed that the plan included in that -/// Stage can complete on its own; its likely holding a leaf node in the overall phyysical plan and -/// producing data from a [`DataSourceExec`]. -#[derive(Debug, Clone)] -pub struct ExecutionStage { - /// Our query_id - pub query_id: Uuid, - /// Our stage number - pub num: usize, - /// Our stage name - pub name: String, - /// The physical execution plan that this stage will execute. - pub plan: Arc, - /// The input stages to this stage - pub inputs: Vec>, - /// Our tasks which tell us how finely grained to execute the partitions in - /// the plan - pub tasks: Vec, - /// tree depth of our location in the stage tree, used for display only - pub depth: usize, -} - -impl ExecutionStage { - /// Creates a new `ExecutionStage` with the given plan and inputs. One task will be created - /// responsible for partitions in the plan. - pub fn new( - query_id: Uuid, - num: usize, - plan: Arc, - inputs: Vec>, - ) -> Self { - let name = format!("Stage {:<3}", num); - let partition_group = (0..plan.properties().partitioning.partition_count()) - .map(|p| p as u64) - .collect(); - ExecutionStage { - query_id, - num, - name, - plan, - inputs: inputs - .into_iter() - .map(|s| s as Arc) - .collect(), - tasks: vec![ExecutionTask::new(partition_group)], - depth: 0, - } - } - - /// Recalculate the tasks for this stage based on the number of partitions in the plan - /// and the maximum number of partitions per task. 
- /// - /// This will unset any worker assignments - pub fn with_maximum_partitions_per_task(mut self, max_partitions_per_task: usize) -> Self { - let partitions = self.plan.properties().partitioning.partition_count(); - - self.tasks = (0..partitions) - .chunks(max_partitions_per_task) - .into_iter() - .map(|partition_group| { - ExecutionTask::new( - partition_group - .collect::>() - .into_iter() - .map(|p| p as u64) - .collect(), - ) - }) - .collect(); - self - } - - /// Returns the name of this stage - pub fn name(&self) -> String { - format!("Stage {:<3}", self.num) - } - - /// Returns an iterator over the child stages of this stage cast as &ExecutionStage - /// which can be useful - pub fn child_stages_iter(&self) -> impl Iterator { - self.inputs - .iter() - .filter_map(|s| s.as_any().downcast_ref::()) - } - - /// Returns the name of this stage including child stage numbers if any. - pub fn name_with_children(&self) -> String { - let child_str = if self.inputs.is_empty() { - "".to_string() - } else { - format!( - " Child Stages:[{}] ", - self.child_stages_iter() - .map(|s| format!("{}", s.num)) - .collect::>() - .join(", ") - ) - }; - format!("Stage {:<3}{}", self.num, child_str) - } - - pub fn try_assign(self, channel_resolver: &impl ChannelResolver) -> Result { - let urls: Vec = channel_resolver.get_urls()?; - if urls.is_empty() { - return internal_err!("No URLs found in ChannelManager"); - } - - Ok(self) - } - - fn try_assign_urls(&self, urls: &[Url]) -> Result { - let assigned_children = self - .child_stages_iter() - .map(|child| { - child - .clone() // TODO: avoid cloning if possible - .try_assign_urls(urls) - .map(|c| Arc::new(c) as Arc) - }) - .collect::>>()?; - - // pick a random starting position - let mut rng = rand::thread_rng(); - let start_idx = rng.gen_range(0..urls.len()); - - let assigned_tasks = self - .tasks - .iter() - .enumerate() - .map(|(i, task)| { - let url = &urls[(start_idx + i) % urls.len()]; - task.clone().with_assignment(url) - }) - .collect::>(); - - let assigned_stage = ExecutionStage { - query_id: self.query_id, - num: self.num, - name: self.name.clone(), - plan: self.plan.clone(), - inputs: assigned_children, - tasks: assigned_tasks, - depth: self.depth, - }; - - Ok(assigned_stage) - } -} - -impl ExecutionPlan for ExecutionStage { - fn name(&self) -> &str { - &self.name - } - - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn children(&self) -> Vec<&Arc> { - self.inputs.iter().collect() - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(ExecutionStage { - query_id: self.query_id, - num: self.num, - name: self.name.clone(), - plan: self.plan.clone(), - inputs: children, - tasks: self.tasks.clone(), - depth: self.depth, - })) - } - - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { - self.plan.properties() - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> datafusion::error::Result { - let stage = self - .as_any() - .downcast_ref::() - .expect("Unwrapping myself should always work"); - - let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config()) - else { - return exec_err!("ChannelManager not found in session config"); - }; - - let urls = channel_resolver.get_urls()?; - - let assigned_stage = stage - .try_assign_urls(&urls) - .map(Arc::new) - .map_err(|e| DataFusionError::Execution(e.to_string()))?; - - // insert the stage into the context so that ExecutionPlan nodes - // that care about the stage can access it - let config = context - 
.session_config() - .clone() - .with_extension(assigned_stage.clone()); - - let new_ctx = - SessionContext::new_with_config_rt(config, context.runtime_env().clone()).task_ctx(); - - assigned_stage.plan.execute(partition, new_ctx) - } -}
diff --git a/src/stage/mod.rs b/src/stage/mod.rs deleted file mode 100644 index b23dea0..0000000 --- a/src/stage/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod display; -mod execution_stage; -mod task; - -pub use display::display_stage_graphviz; -pub use execution_stage::ExecutionStage; -pub use task::ExecutionTask;
diff --git a/src/stage/task.rs b/src/stage/task.rs deleted file mode 100644 index b912834..0000000 --- a/src/stage/task.rs +++ /dev/null @@ -1,71 +0,0 @@ -use core::fmt; -use std::fmt::Display; -use std::fmt::Formatter; - -use datafusion::common::internal_datafusion_err; -use datafusion::error::Result; - -use url::Url; - -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ExecutionTask { - /// The url of the worker that will execute this task. A None value is interpreted as - /// unassinged. - #[prost(string, optional, tag = "1")] - pub url_str: Option<String>, - /// The partitions that we can execute from this plan - #[prost(uint64, repeated, tag = "2")] - pub partition_group: Vec<u64>, -} - -impl ExecutionTask { - pub fn new(partition_group: Vec<u64>) -> Self { - ExecutionTask { - url_str: None, - partition_group, - } - } - - pub fn with_assignment(mut self, url: &Url) -> Self { - self.url_str = Some(format!("{url}")); - self - } - - /// Returns the url of this worker, a None is unassigned - pub fn url(&self) -> Result<Option<Url>> { - self.url_str - .as_ref() - .map(|u| Url::parse(u).map_err(|_| internal_datafusion_err!("Invalid URL: {}", u))) - .transpose() - } -} - -impl Display for ExecutionTask { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!( - f, - "Task: partitions: {},{}]", - format_pg(&self.partition_group), - self.url() - .map_err(|_| std::fmt::Error {})?
- .map(|u| u.to_string()) - .unwrap_or("unassigned".to_string()) - ) - } -} - -pub(crate) fn format_pg(partition_group: &[u64]) -> String { - if partition_group.len() > 2 { - format!( - "{}..{}", - partition_group[0], - partition_group[partition_group.len() - 1] - ) - } else { - partition_group - .iter() - .map(|pg| format!("{pg}")) - .collect::>() - .join(",") - } -} From 39ff6e5ea4e326c80fc91022bcaed1f606420acd Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 4 Sep 2025 22:12:32 +0200 Subject: [PATCH 4/6] Rename channel manager to channel resolver --- src/{channel_manager_ext.rs => channel_resolver_ext.rs} | 0 src/distributed_ext.rs | 2 +- src/execution_plans/arrow_flight_read.rs | 2 +- src/execution_plans/stage.rs | 2 +- src/lib.rs | 4 ++-- 5 files changed, 5 insertions(+), 5 deletions(-) rename src/{channel_manager_ext.rs => channel_resolver_ext.rs} (100%) diff --git a/src/channel_manager_ext.rs b/src/channel_resolver_ext.rs similarity index 100% rename from src/channel_manager_ext.rs rename to src/channel_resolver_ext.rs diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index 503e5dc..b83f852 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -1,4 +1,4 @@ -use crate::channel_manager_ext::set_distributed_channel_resolver; +use crate::channel_resolver_ext::set_distributed_channel_resolver; use crate::config_extension_ext::{ set_distributed_option_extension, set_distributed_option_extension_from_headers, }; diff --git a/src/execution_plans/arrow_flight_read.rs b/src/execution_plans/arrow_flight_read.rs index 86ab6f8..234cbc8 100644 --- a/src/execution_plans/arrow_flight_read.rs +++ b/src/execution_plans/arrow_flight_read.rs @@ -1,4 +1,4 @@ -use crate::channel_manager_ext::get_distributed_channel_resolver; +use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; use crate::execution_plans::StageExec; diff --git a/src/execution_plans/stage.rs b/src/execution_plans/stage.rs index efb72e2..ff42349 100644 --- a/src/execution_plans/stage.rs +++ b/src/execution_plans/stage.rs @@ -1,4 +1,4 @@ -use crate::channel_manager_ext::get_distributed_channel_resolver; +use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::{ArrowFlightReadExec, ChannelResolver, PartitionIsolatorExec}; use datafusion::common::{exec_err, internal_err}; use datafusion::error::{DataFusionError, Result}; diff --git a/src/lib.rs b/src/lib.rs index f07103b..9f06ff3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![deny(clippy::all)] -mod channel_manager_ext; +mod channel_resolver_ext; mod common; mod config_extension_ext; mod distributed_ext; @@ -13,7 +13,7 @@ mod protobuf; #[cfg(any(feature = "integration", test))] pub mod test_utils; -pub use channel_manager_ext::{BoxCloneSyncChannel, ChannelResolver}; +pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver}; pub use distributed_ext::DistributedExt; pub use execution_plans::{ display_stage_graphviz, ArrowFlightReadExec, ExecutionTask, PartitionIsolatorExec, StageExec, From 355da3275360f51b6a709541ba8fae6b4d9005c1 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 4 Sep 2025 22:13:03 +0200 Subject: [PATCH 5/6] Rename physical_optimizer.rs to distributed_physical_optimizer_rule.rs --- ...al_optimizer.rs => distributed_physical_optimizer_rule.rs} | 2 +- src/lib.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename 
src/{physical_optimizer.rs => distributed_physical_optimizer_rule.rs} (99%) diff --git a/src/physical_optimizer.rs b/src/distributed_physical_optimizer_rule.rs similarity index 99% rename from src/physical_optimizer.rs rename to src/distributed_physical_optimizer_rule.rs index 70c0a0a..5f65b22 100644 --- a/src/physical_optimizer.rs +++ b/src/distributed_physical_optimizer_rule.rs @@ -187,7 +187,7 @@ pub fn can_be_divided(plan: &Arc) -> Result { #[cfg(test)] mod tests { use crate::assert_snapshot; - use crate::physical_optimizer::DistributedPhysicalOptimizerRule; + use crate::distributed_physical_optimizer_rule::DistributedPhysicalOptimizerRule; use crate::test_utils::parquet::register_parquet_tables; use datafusion::error::DataFusionError; use datafusion::execution::SessionStateBuilder; diff --git a/src/lib.rs b/src/lib.rs index 9f06ff3..dd6b8ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,10 @@ mod channel_resolver_ext; mod common; mod config_extension_ext; mod distributed_ext; +mod distributed_physical_optimizer_rule; mod errors; mod execution_plans; mod flight_service; -mod physical_optimizer; mod protobuf; #[cfg(any(feature = "integration", test))] @@ -15,6 +15,7 @@ pub mod test_utils; pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver}; pub use distributed_ext::DistributedExt; +pub use distributed_physical_optimizer_rule::DistributedPhysicalOptimizerRule; pub use execution_plans::{ display_stage_graphviz, ArrowFlightReadExec, ExecutionTask, PartitionIsolatorExec, StageExec, }; @@ -23,4 +24,3 @@ pub use flight_service::{ DistributedSessionBuilderContext, MappedDistributedSessionBuilder, MappedDistributedSessionBuilderExt, }; -pub use physical_optimizer::DistributedPhysicalOptimizerRule; From 700b1de1fcc74bce009ede4a9fd8d368bae9b94f Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Fri, 5 Sep 2025 09:20:53 +0200 Subject: [PATCH 6/6] Refactor do_get.rs and adjacent files --- src/channel_resolver_ext.rs | 4 +- src/execution_plans/arrow_flight_read.rs | 14 +-- src/execution_plans/stage.rs | 18 +-- src/flight_service/do_get.rs | 151 ++++++++++++----------- src/flight_service/service.rs | 5 +- 5 files changed, 92 insertions(+), 100 deletions(-) diff --git a/src/channel_resolver_ext.rs b/src/channel_resolver_ext.rs index 1a6df43..33cf739 100644 --- a/src/channel_resolver_ext.rs +++ b/src/channel_resolver_ext.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use datafusion::common::exec_datafusion_err; use datafusion::error::DataFusionError; use datafusion::prelude::SessionConfig; use std::sync::Arc; @@ -16,9 +17,10 @@ pub(crate) fn set_distributed_channel_resolver( pub(crate) fn get_distributed_channel_resolver( cfg: &SessionConfig, -) -> Option> { +) -> Result, DataFusionError> { cfg.get_extension::() .map(|cm| cm.0.clone()) + .ok_or_else(|| exec_datafusion_err!("ChannelResolver not present in the session config")) } #[derive(Clone)] diff --git a/src/execution_plans/arrow_flight_read.rs b/src/execution_plans/arrow_flight_read.rs index 234cbc8..bee5819 100644 --- a/src/execution_plans/arrow_flight_read.rs +++ b/src/execution_plans/arrow_flight_read.rs @@ -147,18 +147,14 @@ impl ExecutionPlan for ArrowFlightReadExec { partition: usize, context: Arc, ) -> Result { - let ArrowFlightReadExec::Ready(this) = self else { + let ArrowFlightReadExec::Ready(self_ready) = self else { return exec_err!("ArrowFlightReadExec is not ready, was the distributed optimization step performed?"); }; // get the channel manager and current stage from our context - let 
Some(channel_resolver) = get_distributed_channel_resolver(context.session_config()) - else { - return exec_err!( - "ArrowFlightReadExec requires a ChannelResolver in the session config" - ); - }; + let channel_resolver = get_distributed_channel_resolver(context.session_config())?; + // the `ArrowFlightReadExec` node can only be executed in the context of a `StageExec` let stage = context .session_config() .get_extension::() @@ -170,10 +166,10 @@ impl ExecutionPlan for ArrowFlightReadExec { // reading from let child_stage = stage .child_stages_iter() - .find(|s| s.num == this.stage_num) + .find(|s| s.num == self_ready.stage_num) .ok_or(internal_datafusion_err!( "ArrowFlightReadExec: no child stage with num {}", - this.stage_num + self_ready.stage_num ))?; let flight_metadata = context diff --git a/src/execution_plans/stage.rs b/src/execution_plans/stage.rs index ff42349..fc1e45a 100644 --- a/src/execution_plans/stage.rs +++ b/src/execution_plans/stage.rs @@ -1,6 +1,6 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::{ArrowFlightReadExec, ChannelResolver, PartitionIsolatorExec}; -use datafusion::common::{exec_err, internal_err}; +use datafusion::common::internal_err; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; use datafusion::physical_plan::{ @@ -260,20 +260,10 @@ impl ExecutionPlan for StageExec { partition: usize, context: Arc, ) -> Result { - let stage = self - .as_any() - .downcast_ref::() - .expect("Unwrapping myself should always work"); - - let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config()) - else { - return exec_err!("ChannelManager not found in session config"); - }; - - let urls = channel_resolver.get_urls()?; + let channel_resolver = get_distributed_channel_resolver(context.session_config())?; - let assigned_stage = stage - .try_assign_urls(&urls) + let assigned_stage = self + .try_assign_urls(&channel_resolver.get_urls()?) .map(Arc::new) .map_err(|e| DataFusionError::Execution(e.to_string()))?; diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index c316525..e979667 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -9,13 +9,14 @@ use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; -use datafusion::execution::SessionState; +use datafusion::execution::{SendableRecordBatchStream, SessionState}; use futures::TryStreamExt; +use http::HeaderMap; use prost::Message; +use std::fmt::Display; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::OnceCell; -use tonic::metadata::MetadataMap; use tonic::{Request, Response, Status}; #[derive(Clone, PartialEq, ::prost::Message)] @@ -41,9 +42,9 @@ pub struct DoGet { /// TaskData stores state for a single task being executed by this Endpoint. It may be shared /// by concurrent requests for the same task which execute separate partitions. pub struct TaskData { - pub(super) state: SessionState, + pub(super) session_state: SessionState, pub(super) stage: Arc, - ///num_partitions_remaining is initialized to the total number of partitions in the task (not + /// `num_partitions_remaining` is initialized to the total number of partitions in the task (not /// only tasks in the partition group). This is decremented for each request to the endpoint /// for this task. Once this count is zero, the task is likely complete. 
The task may not be /// complete because it's possible that the same partition was retried and this count was @@ -56,98 +57,78 @@ impl ArrowFlightEndpoint { &self, request: Request, ) -> Result::DoGetStream>, Status> { - let (metadata, _ext, ticket) = request.into_parts(); - let Ticket { ticket } = ticket; - let doget = DoGet::decode(ticket).map_err(|err| { + let (metadata, _ext, body) = request.into_parts(); + let doget = DoGet::decode(body.ticket).map_err(|err| { Status::invalid_argument(format!("Cannot decode DoGet message: {err}")) })?; + // There's only 1 `StageExec` responsible for all requests that share the same `stage_key`, + // so here we either retrieve the existing one or create a new one if it does not exist. + let (mut session_state, stage) = self + .get_state_and_stage( + doget.stage_key.ok_or_else(missing("stage_key"))?, + doget.stage_proto.ok_or_else(missing("stage_proto"))?, + metadata.clone().into_headers(), + ) + .await?; + + // Find out which partition group we are executing let partition = doget.partition as usize; let task_number = doget.task_number as usize; - let task_data = self.get_state_and_stage(doget, metadata).await?; - - let stage = task_data.stage; - let mut state = task_data.state; - - // find out which partition group we are executing - let task = stage - .tasks - .get(task_number) - .ok_or(Status::invalid_argument(format!( - "Task number {} not found in stage {}", - task_number, - stage.name() - )))?; - - let partition_group = PartitionGroup(task.partition_group.clone()); - state.config_mut().set_extension(Arc::new(partition_group)); - - let inner_plan = stage.plan.clone(); - - let stream = inner_plan - .execute(partition, state.task_ctx()) + let task = stage.tasks.get(task_number).ok_or_else(invalid(format!( + "Task number {task_number} not found in stage {}", + stage.num + )))?; + + let cfg = session_state.config_mut(); + cfg.set_extension(Arc::new(PartitionGroup(task.partition_group.clone()))); + cfg.set_extension(Arc::clone(&stage)); + cfg.set_extension(Arc::new(ContextGrpcMetadata(metadata.into_headers()))); + + // Rather than executing the `StageExec` itself, we want to execute the inner plan instead, + // as executing `StageExec` performs some worker assignation that should have already been + // done in the head stage. 
+ let stream = stage + .plan + .execute(partition, session_state.task_ctx()) .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?; - let flight_data_stream = FlightDataEncoderBuilder::new() - .with_schema(inner_plan.schema().clone()) - .build(stream.map_err(|err| { - FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) - })); - - Ok(Response::new(Box::pin(flight_data_stream.map_err( - |err| match err { - FlightError::Tonic(status) => *status, - _ => Status::internal(format!("Error during flight stream: {err}")), - }, - )))) + Ok(record_batch_stream_to_response(stream)) } async fn get_state_and_stage( &self, - doget: DoGet, - metadata_map: MetadataMap, - ) -> Result { - let key = doget - .stage_key - .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?; - let once_stage = self - .stages + key: StageKey, + stage_proto: StageExecProto, + headers: HeaderMap, + ) -> Result<(SessionState, Arc), Status> { + let once = self + .task_data_entries .get_or_init(key.clone(), || Arc::new(OnceCell::::new())); - let stage_data = once_stage + let stage_data = once .get_or_try_init(|| async { - let stage_proto = doget - .stage_proto - .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; - - let headers = metadata_map.into_headers(); - let mut state = self + let session_state = self .session_builder .build_session_state(DistributedSessionBuilderContext { runtime_env: Arc::clone(&self.runtime), - headers: headers.clone(), + headers, }) .await .map_err(|err| datafusion_error_to_tonic_status(&err))?; - let codec = DistributedCodec::new_combined_with_user(state.config()); + let codec = DistributedCodec::new_combined_with_user(session_state.config()); - let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec) - .map(Arc::new) + let stage = stage_from_proto(stage_proto, &session_state, &self.runtime, &codec) .map_err(|err| { Status::invalid_argument(format!("Cannot decode stage proto: {err}")) })?; - // Add the extensions that might be required for ExecutionPlan nodes in the plan - let config = state.config_mut(); - config.set_extension(stage.clone()); - config.set_extension(Arc::new(ContextGrpcMetadata(headers))); - // Initialize partition count to the number of partitions in the stage let total_partitions = stage.plan.properties().partitioning.partition_count(); Ok::<_, Status>(TaskData { - state, - stage, + session_state, + stage: Arc::new(stage), num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)), }) }) @@ -158,13 +139,37 @@ impl ArrowFlightEndpoint { .num_partitions_remaining .fetch_sub(1, Ordering::SeqCst); if remaining_partitions <= 1 { - self.stages.remove(key.clone()); + self.task_data_entries.remove(key); } - Ok(stage_data.clone()) + Ok((stage_data.session_state.clone(), stage_data.stage.clone())) } } +fn missing(field: &'static str) -> impl FnOnce() -> Status { + move || Status::invalid_argument(format!("Missing field '{field}'")) +} + +fn invalid(msg: impl Display) -> impl FnOnce() -> Status { + move || Status::invalid_argument(msg.to_string()) +} + +fn record_batch_stream_to_response( + stream: SendableRecordBatchStream, +) -> Response<::DoGetStream> { + let flight_data_stream = + FlightDataEncoderBuilder::new() + .with_schema(stream.schema().clone()) + .build(stream.map_err(|err| { + FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) + })); + + Response::new(Box::pin(flight_data_stream.map_err(|err| match err { + FlightError::Tonic(status) => *status, + _ => 
Status::internal(format!("Error during flight stream: {err}")), + }))) +} + #[cfg(test)] mod tests { use super::*; @@ -262,13 +267,13 @@ mod tests { } // Check that the endpoint has not evicted any task states. - assert_eq!(endpoint.stages.len(), num_tasks); + assert_eq!(endpoint.task_data_entries.len(), num_tasks); // Run the last partition of task 0. Any partition number works. Verify that the task state // is evicted because all partitions have been processed. let result = do_get(1, 0, task_keys[0].clone()).await; assert!(result.is_ok()); - let stored_stage_keys = endpoint.stages.keys().collect::>(); + let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 2); assert!(stored_stage_keys.contains(&task_keys[1])); assert!(stored_stage_keys.contains(&task_keys[2])); @@ -276,14 +281,14 @@ mod tests { // Run the last partition of task 1. let result = do_get(1, 1, task_keys[1].clone()).await; assert!(result.is_ok()); - let stored_stage_keys = endpoint.stages.keys().collect::>(); + let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 1); assert!(stored_stage_keys.contains(&task_keys[2])); // Run the last partition of the last task. let result = do_get(1, 2, task_keys[2].clone()).await; assert!(result.is_ok()); - let stored_stage_keys = endpoint.stages.keys().collect::>(); + let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 0); } diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index 3b02cb3..c50ff1d 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -30,8 +30,7 @@ pub struct StageKey { pub struct ArrowFlightEndpoint { pub(super) runtime: Arc, - #[allow(clippy::type_complexity)] - pub(super) stages: TTLMap>>, + pub(super) task_data_entries: TTLMap>>, pub(super) session_builder: Arc, } @@ -42,7 +41,7 @@ impl ArrowFlightEndpoint { let ttl_map = TTLMap::try_new(TTLMapConfig::default())?; Ok(Self { runtime: Arc::new(RuntimeEnv::default()), - stages: ttl_map, + task_data_entries: ttl_map, session_builder: Arc::new(session_builder), }) }
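
Note on the partition-to-task grouping removed from `execution_stage.rs` above (and carried over into `execution_plans/stage.rs`): each task owns a contiguous chunk of at most `max_partitions_per_task` partitions. A minimal std-only sketch of that chunking, not part of the patch; `partition_groups` is an illustrative name, not the crate's API:

```rust
/// Splits `partition_count` output partitions into groups of at most
/// `max_per_task` partitions, one group per task. Mirrors the chunking the
/// patch does with `Itertools::chunks`, using only the standard library.
fn partition_groups(partition_count: usize, max_per_task: usize) -> Vec<Vec<u64>> {
    assert!(max_per_task > 0, "a task must own at least one partition");
    (0..partition_count as u64)
        .collect::<Vec<_>>()
        .chunks(max_per_task)
        .map(|chunk| chunk.to_vec())
        .collect()
}

fn main() {
    // 10 partitions with at most 4 partitions per task -> 3 tasks.
    let groups = partition_groups(10, 4);
    assert_eq!(groups, vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9]]);
    println!("{groups:?}");
}
```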
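Note on `try_assign_urls`: tasks are assigned worker URLs round-robin, starting from a random offset so task 0 does not always land on the first worker in the list. A simplified sketch under the assumption that the `rand` and `url` crates are available (as they are in the patch); `Task` here is a stand-in for `ExecutionTask`, not the crate's type:

```rust
use rand::Rng;
use url::Url;

#[derive(Debug, Clone)]
struct Task {
    url: Option<Url>,
    partition_group: Vec<u64>,
}

/// Assigns worker URLs to tasks round-robin, starting at a random index so
/// repeated queries spread their first task across workers.
fn assign_round_robin(tasks: &mut [Task], urls: &[Url]) {
    // The patch returns an error when no URLs are available; a plain assert
    // keeps this sketch short.
    assert!(!urls.is_empty(), "no worker URLs to assign");
    let start_idx = rand::thread_rng().gen_range(0..urls.len());
    for (i, task) in tasks.iter_mut().enumerate() {
        task.url = Some(urls[(start_idx + i) % urls.len()].clone());
    }
}

fn main() {
    let urls = vec![
        Url::parse("http://worker-0:50051").unwrap(),
        Url::parse("http://worker-1:50051").unwrap(),
    ];
    let mut tasks = vec![
        Task { url: None, partition_group: vec![0, 1] },
        Task { url: None, partition_group: vec![2, 3] },
        Task { url: None, partition_group: vec![4] },
    ];
    assign_round_robin(&mut tasks, &urls);
    for t in &tasks {
        println!("{:?} -> {:?}", t.partition_group, t.url.as_ref().map(Url::as_str));
    }
}
```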
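Note on `format_pg` (deleted from `src/stage/task.rs` and kept alongside the moved task code): partition groups with more than two entries are collapsed to a `first..last` range, smaller groups are comma separated. A standalone copy of that behavior for illustration only, with the function renamed to avoid suggesting it is the crate's public API:

```rust
/// Formats a partition group the way the patch's `format_pg` helper does:
/// long groups collapse to "first..last", short ones are comma separated.
fn format_partition_group(partition_group: &[u64]) -> String {
    if partition_group.len() > 2 {
        format!(
            "{}..{}",
            partition_group[0],
            partition_group[partition_group.len() - 1]
        )
    } else {
        partition_group
            .iter()
            .map(|p| p.to_string())
            .collect::<Vec<_>>()
            .join(",")
    }
}

fn main() {
    assert_eq!(format_partition_group(&[3]), "3");
    assert_eq!(format_partition_group(&[3, 4]), "3,4");
    assert_eq!(format_partition_group(&[3, 4, 5, 6]), "3..6");
}
```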
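Note on the `do_get` ticket: the Flight ticket body is a prost-encoded `DoGet` message carrying the partition number, task number, stage key and serialized stage. A round-trip sketch with a simplified message; the field set and tag numbers here are illustrative, not the crate's actual `DoGet` layout:

```rust
use prost::Message;

/// A simplified stand-in for the `DoGet` ticket message; the real message also
/// carries the serialized stage plan and a `StageKey`.
#[derive(Clone, PartialEq, ::prost::Message)]
struct DoGetTicket {
    #[prost(uint64, tag = "1")]
    partition: u64,
    #[prost(uint64, tag = "2")]
    task_number: u64,
    #[prost(string, tag = "3")]
    query_id: String,
}

fn main() -> Result<(), prost::DecodeError> {
    let ticket = DoGetTicket { partition: 3, task_number: 1, query_id: "q-42".into() };
    // What the client would place into `Ticket { ticket }` ...
    let bytes = ticket.encode_to_vec();
    // ... and what the endpoint does with the bytes before building a session.
    let decoded = DoGetTicket::decode(bytes.as_slice())?;
    assert_eq!(decoded, ticket);
    Ok(())
}
```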
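Note on the per-task state sharing tested above: `ArrowFlightEndpoint` keeps one `TaskData` per `StageKey`; the first request initializes it through an async `OnceCell`, concurrent requests for other partitions of the same task reuse it, and an atomic countdown evicts the entry once every partition has been served. A simplified, self-contained version of that pattern, assuming a `tokio` runtime; a plain `Mutex<HashMap<..>>` and `String` keys stand in for the crate's `TTLMap` and `StageKey`, and `TaskState` is a stripped-down stand-in for `TaskData`:

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::OnceCell;

#[derive(Clone)]
struct TaskState {
    name: String,
    partitions_remaining: Arc<AtomicUsize>,
}

#[derive(Default)]
struct Endpoint {
    // The crate uses a TTLMap so abandoned entries eventually expire; a plain
    // HashMap is enough to show the init/reuse/evict flow.
    tasks: Mutex<HashMap<String, Arc<OnceCell<TaskState>>>>,
}

impl Endpoint {
    async fn get_or_init_task(&self, key: &str, total_partitions: usize) -> TaskState {
        // Clone the cell out so the map lock is not held across the await.
        let cell = {
            let mut map = self.tasks.lock().unwrap();
            Arc::clone(map.entry(key.to_string()).or_default())
        };
        // Only the first caller runs the initializer; later callers reuse it.
        let state = cell
            .get_or_init(|| async {
                TaskState {
                    name: format!("task for {key}"),
                    partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
                }
            })
            .await
            .clone();
        // Each partition request decrements the counter; the last one evicts the entry.
        if state.partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
            self.tasks.lock().unwrap().remove(key);
        }
        state
    }
}

#[tokio::main]
async fn main() {
    let endpoint = Endpoint::default();
    for partition in 0..3 {
        let state = endpoint.get_or_init_task("stage-1/task-0", 3).await;
        println!("partition {partition} used {}", state.name);
    }
    // All three partitions have been served, so the entry was evicted.
    assert!(endpoint.tasks.lock().unwrap().is_empty());
}
```

As in the patch, the countdown is only a heuristic for completion: a retried partition would decrement the counter again, which is why the comment on `num_partitions_remaining` calls the task "likely" complete.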