From 7936ba91cbe0bfce314668e43d49c1ad63a07de6 Mon Sep 17 00:00:00 2001
From: Gabriel Musat Mestre
Date: Mon, 4 Aug 2025 08:08:41 +0200
Subject: [PATCH] Remove stage delegation in favor of planning-time stage
 assignment

---
 src/channel_manager.rs                  |  56 +++-
 src/context.rs                          |  20 ++
 src/flight_service/do_get.rs            |  90 +++---
 src/flight_service/do_put.rs            |  81 -----
 src/flight_service/mod.rs               |   2 -
 src/flight_service/service.rs           |   7 +-
 .../stream_partitioner_registry.rs      |  14 +-
 src/lib.rs                              |   8 +-
 src/plan/arrow_flight_read.rs           | 265 +++++-----------
 src/plan/arrow_flight_read_proto.rs     |  69 ++++-
 src/plan/assign_stages.rs               |  73 +++++
 src/plan/mod.rs                         |   2 +
 src/stage_delegation/context.rs         |  25 --
 src/stage_delegation/delegation.rs      | 290 ------------------
 src/stage_delegation/mod.rs             |   5 -
 tests/common/insta.rs                   |  10 +-
 tests/common/localhost.rs               |  31 +-
 tests/custom_extension_codec.rs         |  13 +-
 tests/distributed_aggregation.rs        |   8 +-
 tests/error_propagation.rs              |   4 +-
 tests/highly_distributed_query.rs       |  10 +-
 21 files changed, 363 insertions(+), 720 deletions(-)
 create mode 100644 src/context.rs
 delete mode 100644 src/flight_service/do_put.rs
 create mode 100644 src/plan/assign_stages.rs
 delete mode 100644 src/stage_delegation/context.rs
 delete mode 100644 src/stage_delegation/delegation.rs
 delete mode 100644 src/stage_delegation/mod.rs

diff --git a/src/channel_manager.rs b/src/channel_manager.rs
index db88f639..08c60e62 100644
--- a/src/channel_manager.rs
+++ b/src/channel_manager.rs
@@ -1,12 +1,14 @@
 use async_trait::async_trait;
 use datafusion::common::internal_datafusion_err;
 use datafusion::error::DataFusionError;
-use datafusion::prelude::SessionConfig;
+use datafusion::execution::TaskContext;
+use datafusion::prelude::{SessionConfig, SessionContext};
 use delegate::delegate;
 use std::sync::Arc;
 use tonic::body::BoxBody;
 use url::Url;
 
+#[derive(Clone)]
 pub struct ChannelManager(Arc<dyn ChannelResolver + Send + Sync>);
 
 impl ChannelManager {
@@ -21,29 +23,49 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
     tonic::transport::Error,
 >;
 
-#[derive(Clone, Debug)]
-pub struct ArrowFlightChannel {
-    pub url: Url,
-    pub channel: BoxCloneSyncChannel,
-}
-
+/// Abstracts networking details so that users can implement their own network resolution
+/// mechanism.
 #[async_trait]
 pub trait ChannelResolver {
-    async fn get_n_channels(&self, n: usize) -> Result<Vec<ArrowFlightChannel>, DataFusionError>;
-    async fn get_channel_for_url(&self, url: &Url) -> Result<ArrowFlightChannel, DataFusionError>;
+    /// Gets all available worker URLs. Used during stage assignment.
+    fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
+    /// For a given URL, get a channel for communicating with it.
+    async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
 }
 
 impl ChannelManager {
-    pub fn try_from_session(session: &SessionConfig) -> Result<Arc<Self>, DataFusionError> {
-        session
-            .get_extension::<ChannelManager>()
-            .ok_or_else(|| internal_datafusion_err!("No extension ChannelManager"))
-    }
-
     delegate! {
         to self.0 {
-            pub async fn get_n_channels(&self, n: usize) -> Result<Vec<ArrowFlightChannel>, DataFusionError>;
-            pub async fn get_channel_for_url(&self, url: &Url) -> Result<ArrowFlightChannel, DataFusionError>;
+            pub fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
+            pub async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
         }
     }
 }
+
+impl TryInto<ChannelManager> for &SessionConfig {
+    type Error = DataFusionError;
+
+    fn try_into(self) -> Result<ChannelManager, Self::Error> {
+        Ok(self
+            .get_extension::<ChannelManager>()
+            .ok_or_else(|| internal_datafusion_err!("No extension ChannelManager"))?
+            .as_ref()
+            .clone())
+    }
+}
+
+impl TryInto<ChannelManager> for &TaskContext {
+    type Error = DataFusionError;
+
+    fn try_into(self) -> Result<ChannelManager, Self::Error> {
+        self.session_config().try_into()
+    }
+}
+
+impl TryInto<ChannelManager> for &SessionContext {
+    type Error = DataFusionError;
+
+    fn try_into(self) -> Result<ChannelManager, Self::Error> {
+        self.task_ctx().as_ref().try_into()
+    }
+}
diff --git a/src/context.rs b/src/context.rs
new file mode 100644
index 00000000..eae1e12d
--- /dev/null
+++ b/src/context.rs
@@ -0,0 +1,20 @@
+use url::Url;
+use uuid::Uuid;
+
+#[derive(Debug, Clone)]
+pub struct StageContext {
+    /// Unique identifier of the Stage.
+    pub id: Uuid,
+    /// Number of tasks involved in the query.
+    pub n_tasks: usize,
+    /// Unique identifier of the input Stage.
+    pub input_id: Uuid,
+    /// Urls from which the current stage will need to read data.
+    pub input_urls: Vec<Url>,
+}
+
+#[derive(Debug, Clone)]
+pub struct StageTaskContext {
+    /// Index of the current task in a stage
+    pub task_idx: usize,
+}
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index 47713737..347691da 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -1,8 +1,8 @@
 use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
+use crate::context::StageTaskContext;
 use crate::errors::datafusion_error_to_tonic_status;
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::plan::ArrowFlightReadExecProtoCodec;
-use crate::stage_delegation::{ActorContext, StageContext};
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
@@ -10,15 +10,17 @@ use arrow_flight::Ticket;
 use datafusion::error::DataFusionError;
 use datafusion::execution::SessionStateBuilder;
 use datafusion::optimizer::OptimizerConfig;
-use datafusion::physical_expr::Partitioning;
+use datafusion::physical_expr::{Partitioning, PhysicalExpr};
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
+use datafusion_proto::physical_plan::from_proto::parse_physical_exprs;
+use datafusion_proto::physical_plan::to_proto::serialize_physical_exprs;
 use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
-use datafusion_proto::protobuf::PhysicalPlanNode;
+use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode};
 use futures::TryStreamExt;
 use prost::Message;
 use std::sync::Arc;
 use tonic::{Request, Response, Status};
+use uuid::Uuid;
 
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct DoGet {
@@ -35,26 +37,38 @@ pub enum DoGetInner {
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct RemotePlanExec {
     #[prost(message, optional, boxed, tag = "1")]
-    plan: Option<Box<PhysicalPlanNode>>,
-    #[prost(message, optional, tag = "2")]
"2")] - stage_context: Option, - #[prost(message, optional, tag = "3")] - actor_context: Option, + pub plan: Option>, + #[prost(string, tag = "2")] + pub stage_id: String, + #[prost(uint64, tag = "3")] + pub task_idx: u64, + #[prost(uint64, tag = "4")] + pub output_task_idx: u64, + #[prost(uint64, tag = "5")] + pub output_tasks: u64, + #[prost(message, repeated, tag = "6")] + pub hash_expr: Vec, } impl DoGet { pub fn new_remote_plan_exec_ticket( plan: Arc, - stage_context: StageContext, - actor_context: ActorContext, + stage_id: Uuid, + task_idx: usize, + output_task_idx: usize, + output_tasks: usize, + hash_expr: &[Arc], extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { let node = PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?; let do_get = Self { inner: Some(DoGetInner::RemotePlanExec(RemotePlanExec { plan: Some(Box::new(node)), - stage_context: Some(stage_context), - actor_context: Some(actor_context), + stage_id: stage_id.to_string(), + task_idx: task_idx as u64, + output_task_idx: output_task_idx as u64, + output_tasks: output_tasks as u64, + hash_expr: serialize_physical_exprs(hash_expr, extension_codec)?, })), }; Ok(Ticket::new(do_get.encode_to_vec())) @@ -91,14 +105,6 @@ impl ArrowFlightEndpoint { return invalid_argument("RemotePlanExec is missing the plan"); }; - let Some(stage_context) = action.stage_context else { - return invalid_argument("RemotePlanExec is missing the stage context"); - }; - - let Some(actor_context) = action.actor_context else { - return invalid_argument("RemotePlanExec is missing the actor context"); - }; - let mut codec = ComposedPhysicalExtensionCodec::default(); codec.push(ArrowFlightReadExecProtoCodec); codec.push_from_config(state.config()); @@ -107,40 +113,34 @@ impl ArrowFlightEndpoint { .try_into_physical_plan(function_registry, &self.runtime, &codec) .map_err(|err| Status::internal(format!("Cannot deserialize plan: {err}")))?; - let stage_id = stage_context.id.clone(); - let caller_actor_idx = actor_context.caller_actor_idx as usize; - let actor_idx = actor_context.actor_idx as usize; - let prev_n = stage_context.prev_actors as usize; - let partitioning = match parse_protobuf_partitioning( - stage_context.partitioning.as_ref(), + let stage_id = Uuid::parse_str(&action.stage_id).map_err(|err| { + Status::invalid_argument(format!( + "Cannot parse stage id '{}': {err}", + action.stage_id + )) + })?; + + let task_idx = action.task_idx as usize; + let caller_actor_idx = action.output_task_idx as usize; + let prev_n = action.output_tasks as usize; + let partitioning = match parse_physical_exprs( + &action.hash_expr, function_registry, &plan.schema(), &codec, ) { - // We need to replace the partition count in the provided Partitioning scheme with - // the number of actors in the previous stage. ArrowFlightReadExec might be declaring - // N partitions, but each ArrowFlightReadExec::execute(n) call will go to a different - // actor in the next stage. - // - // Each actor in that next stage (us here) needs to expose as many partitioned streams - // as actors exist on its previous stage. 
-            Ok(Some(partitioning)) => match partitioning {
-                Partitioning::RoundRobinBatch(_) => Partitioning::RoundRobinBatch(prev_n),
-                Partitioning::Hash(expr, _) => Partitioning::Hash(expr, prev_n),
-                Partitioning::UnknownPartitioning(_) => Partitioning::UnknownPartitioning(prev_n),
-            },
-            Ok(None) => return invalid_argument("Missing partitioning"),
-            Err(err) => return invalid_argument(format!("Cannot parse partitioning {err}")),
+            Ok(expr) if expr.is_empty() => Partitioning::RoundRobinBatch(prev_n),
+            Ok(expr) => Partitioning::Hash(expr, prev_n),
+            Err(err) => return invalid_argument(format!("Cannot parse hash expressions {err}")),
         };
+
         let config = state.config_mut();
-        config.set_extension(Arc::clone(&self.stage_delegation));
         config.set_extension(Arc::clone(&self.channel_manager));
-        config.set_extension(Arc::new(stage_context));
-        config.set_extension(Arc::new(actor_context));
+        config.set_extension(Arc::new(StageTaskContext { task_idx }));
 
         let stream_partitioner = self
             .partitioner_registry
-            .get_or_create_stream_partitioner(stage_id, actor_idx, plan, partitioning)
+            .get_or_create_stream_partitioner(stage_id, task_idx, plan, partitioning)
             .map_err(|err| datafusion_error_to_tonic_status(&err))?;
 
         let stream = stream_partitioner
diff --git a/src/flight_service/do_put.rs b/src/flight_service/do_put.rs
deleted file mode 100644
index c3b78457..00000000
--- a/src/flight_service/do_put.rs
+++ /dev/null
@@ -1,81 +0,0 @@
-use crate::flight_service::service::ArrowFlightEndpoint;
-use crate::stage_delegation::StageContext;
-use arrow_flight::flight_service_server::FlightService;
-use arrow_flight::FlightData;
-use futures::StreamExt;
-use prost::Message;
-use tonic::{Request, Response, Status, Streaming};
-
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DoPut {
-    #[prost(oneof = "DoPutInner", tags = "1")]
-    pub inner: Option<DoPutInner>,
-}
-
-#[derive(Clone, PartialEq, prost::Oneof)]
-pub enum DoPutInner {
-    #[prost(message, tag = "1")]
-    StageContext(StageContextExt),
-}
-
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct StageContextExt {
-    #[prost(message, optional, tag = "1")]
-    pub stage_context: Option<StageContext>,
-    #[prost(string, tag = "2")]
-    pub stage_id: String,
-    #[prost(uint64, tag = "3")]
-    pub actor_idx: u64,
-}
-
-impl DoPut {
-    pub fn new_stage_context_flight_data(
-        stage_id: String,
-        actor_idx: usize,
-        next_stage_context: StageContext,
-    ) -> FlightData {
-        let this = Self {
-            inner: Some(DoPutInner::StageContext(StageContextExt {
-                stage_id,
-                actor_idx: actor_idx as u64,
-                stage_context: Some(next_stage_context),
-            })),
-        };
-
-        FlightData::new().with_data_body(this.encode_to_vec())
-    }
-}
-
-impl ArrowFlightEndpoint {
-    pub(super) async fn put(
-        &self,
-        request: Request<Streaming<FlightData>>,
-    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
-        let mut stream = request.into_inner();
-        while let Some(msg) = stream.message().await?
{ - let action = DoPut::decode(msg.data_body).map_err(|err| { - Status::invalid_argument(format!("Cannot decode DoPut message: {err}")) - })?; - let Some(action) = action.inner else { - return Err(Status::invalid_argument("DoPut message is empty")); - }; - match action { - DoPutInner::StageContext(stage_context_ext) => { - let Some(stage_context) = stage_context_ext.stage_context else { - return Err(Status::invalid_argument("StageContext is empty")); - }; - let stage_id = stage_context_ext.stage_id; - let actor_idx = stage_context_ext.actor_idx as usize; - self.stage_delegation - .add_delegate_info(stage_id, actor_idx, stage_context) - .map_err(|err| { - Status::internal(format!( - "Cannot add delegate to stage_delegation: {err}" - )) - })?; - } - } - } - Ok(Response::new(futures::stream::empty().boxed())) - } -} diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index 5f825578..b773bad9 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -1,11 +1,9 @@ mod do_get; -mod do_put; mod service; mod session_builder; mod stream_partitioner_registry; pub(crate) use do_get::DoGet; -pub(crate) use do_put::DoPut; pub use service::ArrowFlightEndpoint; pub use session_builder::SessionBuilder; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index 4e884c86..b761f41f 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -2,7 +2,6 @@ use crate::channel_manager::ChannelManager; use crate::flight_service::session_builder::NoopSessionBuilder; use crate::flight_service::stream_partitioner_registry::StreamPartitionerRegistry; use crate::flight_service::SessionBuilder; -use crate::stage_delegation::StageDelegation; use crate::ChannelResolver; use arrow_flight::flight_service_server::FlightService; use arrow_flight::{ @@ -17,7 +16,6 @@ use tonic::{Request, Response, Status, Streaming}; pub struct ArrowFlightEndpoint { pub(super) channel_manager: Arc, - pub(super) stage_delegation: Arc, pub(super) runtime: Arc, pub(super) partitioner_registry: Arc, pub(super) session_builder: Arc, @@ -27,7 +25,6 @@ impl ArrowFlightEndpoint { pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self { Self { channel_manager: Arc::new(ChannelManager::new(channel_resolver)), - stage_delegation: Arc::new(StageDelegation::default()), runtime: Arc::new(RuntimeEnv::default()), partitioner_registry: Arc::new(StreamPartitionerRegistry::default()), session_builder: Arc::new(NoopSessionBuilder), @@ -96,9 +93,9 @@ impl FlightService for ArrowFlightEndpoint { async fn do_put( &self, - request: Request>, + _: Request>, ) -> Result, Status> { - self.put(request).await + Err(Status::unimplemented("Not yet implemented")) } type DoExchangeStream = BoxStream<'static, Result>; diff --git a/src/flight_service/stream_partitioner_registry.rs b/src/flight_service/stream_partitioner_registry.rs index 2db17e2c..005a7698 100644 --- a/src/flight_service/stream_partitioner_registry.rs +++ b/src/flight_service/stream_partitioner_registry.rs @@ -3,14 +3,14 @@ use datafusion::error::DataFusionError; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; use std::sync::Arc; - +use uuid::Uuid; // TODO: find some way of cleaning up abandoned partitioners /// Keeps track of all the [StreamPartitioner] currently running in the program, identifying them /// by stage id. 
#[derive(Default)] pub struct StreamPartitionerRegistry { - map: DashMap<(String, usize), Arc>, + map: DashMap<(Uuid, usize), Arc>, } impl StreamPartitionerRegistry { @@ -18,7 +18,7 @@ impl StreamPartitionerRegistry { /// If there was already one, return a reference to it. pub fn get_or_create_stream_partitioner( &self, - id: String, + id: Uuid, actor_idx: usize, plan: Arc, partitioning: Partitioning, @@ -48,7 +48,7 @@ mod tests { let registry = StreamPartitionerRegistry::default(); let partitioner = registry.get_or_create_stream_partitioner( - "test".to_string(), + Uuid::new_v4(), 0, mock_exec(15, 10), Partitioning::RoundRobinBatch(PARTITIONS), @@ -69,7 +69,7 @@ mod tests { let registry = StreamPartitionerRegistry::default(); let partitioner = registry.get_or_create_stream_partitioner( - "test".to_string(), + Uuid::new_v4(), 0, mock_exec(5, 10), Partitioning::RoundRobinBatch(PARTITIONS), @@ -87,7 +87,7 @@ mod tests { let registry = StreamPartitionerRegistry::default(); let partitioner = registry.get_or_create_stream_partitioner( - "test".to_string(), + Uuid::new_v4(), 0, mock_exec(15, 10), Partitioning::Hash(vec![col("c0", &test_schema())?], PARTITIONS), @@ -108,7 +108,7 @@ mod tests { let registry = StreamPartitionerRegistry::default(); let partitioner = registry.get_or_create_stream_partitioner( - "test".to_string(), + Uuid::new_v4(), 0, mock_exec(5, 10), Partitioning::Hash(vec![col("c0", &test_schema())?], PARTITIONS), diff --git a/src/lib.rs b/src/lib.rs index d48ecb85..b18a7397 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,12 @@ mod channel_manager; mod composed_extension_codec; +pub(crate) mod context; mod errors; mod flight_service; mod plan; -mod stage_delegation; #[cfg(test)] pub mod test_utils; -pub use channel_manager::{ - ArrowFlightChannel, BoxCloneSyncChannel, ChannelManager, ChannelResolver, -}; +pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; pub use flight_service::{ArrowFlightEndpoint, SessionBuilder}; -pub use plan::ArrowFlightReadExec; +pub use plan::{assign_stages, ArrowFlightReadExec}; diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index dce2472d..ca9635f4 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,36 +1,30 @@ -use crate::channel_manager::{ArrowFlightChannel, ChannelManager}; +use crate::channel_manager::ChannelManager; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::context::{StageContext, StageTaskContext}; use crate::errors::tonic_status_to_datafusion_error; -use crate::flight_service::{DoGet, DoPut}; +use crate::flight_service::DoGet; use crate::plan::arrow_flight_read_proto::ArrowFlightReadExecProtoCodec; -use crate::stage_delegation::{ActorContext, StageContext, StageDelegation}; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; use arrow_flight::flight_service_client::FlightServiceClient; -use datafusion::common::runtime::JoinSet; -use datafusion::common::{exec_datafusion_err, internal_err, plan_err}; +use datafusion::common::{internal_err, plan_err}; use datafusion::error::DataFusionError; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; -use 
datafusion_proto::physical_plan::to_proto::serialize_partitioning; -use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use futures::{TryFutureExt, TryStreamExt}; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; -use tokio::sync::OnceCell; use tonic::IntoRequest; -use url::{ParseError, Url}; -use uuid::Uuid; #[derive(Debug, Clone)] pub struct ArrowFlightReadExec { properties: PlanProperties, child: Arc, - next_stage_context_cell: Arc)>>, + pub(crate) stage_context: Option, } impl ArrowFlightReadExec { @@ -43,32 +37,44 @@ impl ArrowFlightReadExec { Boundedness::Bounded, ), child, - next_stage_context_cell: Arc::new(OnceCell::new()), + stage_context: None, } } } impl DisplayAs for ArrowFlightReadExec { fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - match &self.properties.partitioning { - Partitioning::RoundRobinBatch(size) => { - write!(f, "ArrowFlightReadExec: input_actors={size}") - } - Partitioning::Hash(phy_exprs, size) => { - let phy_exprs_str = phy_exprs + let (hash_expr, size) = match &self.properties.partitioning { + Partitioning::RoundRobinBatch(size) => (vec![], size), + Partitioning::Hash(hash_expr, size) => (hash_expr.clone(), size), + Partitioning::UnknownPartitioning(size) => (vec![], size), + }; + + let hash_expr = hash_expr + .iter() + .map(|e| format!("{e}")) + .collect::>() + .join(", "); + + let stage_trail = match &self.stage_context { + None => " (Unassigned stage)".to_string(), + Some(stage) => format!( + " stage_id={} input_stage_id={} input_hosts=[{}]", + stage.id, + stage.input_id, + stage + .input_urls .iter() - .map(|e| format!("{e}")) - .collect::>() - .join(", "); - write!( - f, - "ArrowFlightReadExec: input_actors={size} hash=[{phy_exprs_str}]" - ) - } - Partitioning::UnknownPartitioning(size) => { - write!(f, "ArrowFlightReadExec: input_actors={size}") - } - } + .map(|url| url.to_string()) + .collect::>() + .join(", ") + ), + }; + + write!( + f, + "ArrowFlightReadExec: input_tasks={size} hash_expr=[{hash_expr}]{stage_trail}", + ) } } @@ -102,7 +108,7 @@ impl ExecutionPlan for ArrowFlightReadExec { Ok(Arc::new(Self { properties: self.properties.clone(), child: Arc::clone(&children[0]), - next_stage_context_cell: Arc::new(OnceCell::new()), + stage_context: self.stage_context.clone(), })) } @@ -111,66 +117,30 @@ impl ExecutionPlan for ArrowFlightReadExec { partition: usize, context: Arc, ) -> datafusion::common::Result { - let partitioning = self.properties.partitioning.clone(); - - let channel_manager = ChannelManager::try_from_session(context.session_config())?; + let plan = Arc::clone(&self.child); + let channel_manager: ChannelManager = context.as_ref().try_into()?; - let current_actor_opt = context.session_config().get_extension::(); - let current_stage_opt = context.session_config().get_extension::(); - let stage_delegation_opt = context.session_config().get_extension::(); - if current_stage_opt.is_some() && stage_delegation_opt.is_none() { - return internal_err!("No StageDelegation extension found in the SessionConfig even though a StageContext was present."); - } - if current_stage_opt.is_some() && current_actor_opt.is_none() { - return internal_err!("No ActorContext extension found in the SessionConfig even though a StageContext was present."); - } - let current_actor = current_actor_opt.unwrap_or_default(); + let Some(stage) = self.stage_context.clone() else { + return plan_err!("No stage assigned to this ArrowFlightReadExec"); + }; + let task_context = 
context.session_config().get_extension::(); - let plan = Arc::clone(&self.child); - let next_stage_context = Arc::clone(&self.next_stage_context_cell); + let hash_expr = match &self.properties.partitioning { + Partitioning::Hash(hash_expr, _) => hash_expr.clone(), + _ => vec![], + }; let stream = async move { - let (next_stage_context, channels) = next_stage_context.get_or_try_init(|| async { - if let Some(ref current_stage) = current_stage_opt { - if current_actor.actor_idx == current_stage.delegate { - // We are inside a stage, and we are the delegate, so need to - // build the channels and communicate them. - build_next_stage(&channel_manager, Some(current_stage), partitioning).await - } else { - // We are inside a stage, but we are not the delegate, so we need to - // wait for the delegate to tell us what the new channels are. - let Some(stage_delegation) = stage_delegation_opt else { - return internal_err!("No StageDelegation extension found in the SessionConfig even though a StageContext was present."); - }; - listen_to_next_stage( - &channel_manager, - &stage_delegation, - current_stage.id.clone(), - current_actor.actor_idx as usize - ).await - } - } else { - // We are not in a stage, the whole thing starts here. - build_next_stage(&channel_manager, None, partitioning).await - } - }).await?; - - if let Some(current_stage) = current_stage_opt { - if current_actor.actor_idx == current_stage.delegate { - // We are the delegate, and it's our duty to communicate the next stage context - // to the other actors that are not us. They will be waiting for us to send - // them this info. - communicate_next_stage( - Arc::clone(&channel_manager), - current_stage.as_ref().clone(), - next_stage_context.clone() - ).await?; - } + if partition >= stage.input_urls.len() { + return internal_err!( + "Invalid partition {partition} for a stage with only {} inputs", + stage.input_urls.len() + ); } - if partition >= channels.len() { - return internal_err!("Invalid channel index {partition} with a total number of {} channels", channels.len()); - } + let channel = channel_manager + .get_channel_for_url(&stage.input_urls[partition]) + .await?; let mut codec = ComposedPhysicalExtensionCodec::default(); codec.push(ArrowFlightReadExecProtoCodec); @@ -178,31 +148,34 @@ impl ExecutionPlan for ArrowFlightReadExec { let ticket = DoGet::new_remote_plan_exec_ticket( plan, - next_stage_context.clone(), - ActorContext { - caller_actor_idx: current_actor.actor_idx, - actor_idx: partition as u64, - }, - &codec + stage.input_id, + partition, + task_context.as_ref().map(|v| v.task_idx).unwrap_or(0), + stage.n_tasks, + &hash_expr, + &codec, )?; - let mut client = FlightServiceClient::new(channels[partition].channel.clone()); + let mut client = FlightServiceClient::new(channel); let stream = client .do_get(ticket.into_request()) .await - .map_err(|err| tonic_status_to_datafusion_error(&err).unwrap_or_else(|| { - DataFusionError::External(Box::new(err)) - }))? + .map_err(|err| { + tonic_status_to_datafusion_error(&err) + .unwrap_or_else(|| DataFusionError::External(Box::new(err))) + })? 
.into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); - Ok(FlightRecordBatchStream::new_from_flight_data(stream) - .map_err(|err| match err { + Ok( + FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err { FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status) .unwrap_or_else(|| DataFusionError::External(Box::new(status))), - err => DataFusionError::External(Box::new(err)) - })) - }.try_flatten_stream(); + err => DataFusionError::External(Box::new(err)), + }), + ) + } + .try_flatten_stream(); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), @@ -210,95 +183,3 @@ impl ExecutionPlan for ArrowFlightReadExec { ))) } } - -/// Builds the next stage context. This should be done by either the delegate in case we are already -/// inside a stage context, or unconditionally if we are not in a stage context. -async fn build_next_stage( - channel_manager: &ChannelManager, - current_stage: Option<&StageContext>, - partitioning: Partitioning, -) -> Result<(StageContext, Vec), DataFusionError> { - let output_partitions = partitioning.partition_count(); - let channels = channel_manager.get_n_channels(output_partitions).await?; - - let next_stage_context = StageContext { - id: Uuid::new_v4().to_string(), - partitioning: Some(serialize_partitioning( - &partitioning, - // TODO: this should be set by the user - &DefaultPhysicalExtensionCodec {}, - )?), - delegate: 0, - actors: channels.iter().map(|t| t.url.to_string()).collect(), - prev_actors: current_stage.map(|v| v.actors.len()).unwrap_or(1) as u64, - }; - Ok((next_stage_context, channels)) -} - -/// Communicates the next stage context to all the actors that are not us. This should be -/// done by the delegate in a stage, as it's the one responsible for ensuring every actor in -/// a stage knows how the next stage looks like. -async fn communicate_next_stage( - channel_manager: Arc, - current_stage: StageContext, - next_stage: StageContext, -) -> Result<(), DataFusionError> { - let actors = current_stage - .actors - .iter() - .enumerate() - // Do not communicate to self. - .filter(|(i, _)| *i != current_stage.delegate as usize) - .map(|(i, url)| Ok((i, Url::parse(url.as_str())?))) - .collect::, _>>() - .map_err(|err: ParseError| { - exec_datafusion_err!("Invalid actor Urls in next stage context: {err}") - })?; - - let mut join_set = JoinSet::new(); - for (actor_idx, url) in actors { - let stage_id = current_stage.id.clone(); - let next_stage = next_stage.clone(); - let channel_manager = Arc::clone(&channel_manager); - join_set.spawn(async move { - let flight_data = DoPut::new_stage_context_flight_data(stage_id, actor_idx, next_stage); - - let channel = channel_manager.get_channel_for_url(&url).await?; - let mut client = FlightServiceClient::new(channel.channel.clone()); - client - .do_put(futures::stream::once(async move { flight_data })) - .await - .map_err(|err| DataFusionError::External(Box::new(err))) - }); - } - for res in join_set.join_all().await { - res?; - } - Ok(()) -} - -/// Waits until the delegate in the current stage communicates us the next stage context. It's -/// the responsibility of the delegate to choose the next stage context, and other actors in the -/// stage must wait for that info to be communicated. This function does just that. 
-async fn listen_to_next_stage( - channel_manager: &ChannelManager, - stage_delegation: &StageDelegation, - stage_id: String, - actor_idx: usize, -) -> Result<(StageContext, Vec), DataFusionError> { - let next_stage_context = stage_delegation - .wait_for_delegate_info(stage_id, actor_idx) - .await?; - let urls = next_stage_context - .actors - .iter() - .map(|a| Url::parse(a.as_str())) - .collect::, _>>() - .map_err(|err| exec_datafusion_err!("Invalid actor Urls in next stage context: {err}"))?; - let channel_futures = urls - .iter() - .map(|url| channel_manager.get_channel_for_url(url)); - - let channels = futures::future::try_join_all(channel_futures).await?; - Ok((next_stage_context, channels)) -} diff --git a/src/plan/arrow_flight_read_proto.rs b/src/plan/arrow_flight_read_proto.rs index 689a25f2..b5e3f5d9 100644 --- a/src/plan/arrow_flight_read_proto.rs +++ b/src/plan/arrow_flight_read_proto.rs @@ -1,4 +1,6 @@ +use crate::context::StageContext; use crate::plan::arrow_flight_read::ArrowFlightReadExec; +use datafusion::error::DataFusionError; use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning; @@ -8,6 +10,8 @@ use datafusion_proto::protobuf; use datafusion_proto::protobuf::proto_error; use prost::Message; use std::sync::Arc; +use url::Url; +use uuid::Uuid; /// DataFusion [PhysicalExtensionCodec] implementation that allows sending and receiving /// [ArrowFlightReadExecProto] over the wire. @@ -21,8 +25,10 @@ impl PhysicalExtensionCodec for ArrowFlightReadExecProtoCodec { inputs: &[Arc], registry: &dyn FunctionRegistry, ) -> datafusion::common::Result> { - let ArrowFlightReadExecProto { partitioning } = - ArrowFlightReadExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?; + let ArrowFlightReadExecProto { + partitioning, + stage_context, + } = ArrowFlightReadExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?; if inputs.len() != 1 { return Err(proto_error(format!( @@ -31,6 +37,10 @@ impl PhysicalExtensionCodec for ArrowFlightReadExecProtoCodec { ))); } + let Some(stage_context) = stage_context else { + return Err(proto_error("Missing stage context")); + }; + let Some(partitioning) = parse_protobuf_partitioning( partitioning.as_ref(), registry, @@ -40,10 +50,26 @@ impl PhysicalExtensionCodec for ArrowFlightReadExecProtoCodec { else { return Err(proto_error("Partitioning not specified")); }; - Ok(Arc::new(ArrowFlightReadExec::new( - inputs[0].clone(), - partitioning, - ))) + let mut node = ArrowFlightReadExec::new(inputs[0].clone(), partitioning); + + fn parse_uuid(uuid: &str) -> Result { + uuid.parse::() + .map_err(|err| proto_error(format!("{err}"))) + } + + node.stage_context = Some(StageContext { + id: parse_uuid(&stage_context.id)?, + n_tasks: stage_context.n_tasks as usize, + input_id: parse_uuid(&stage_context.input_id)?, + input_urls: stage_context + .input_urls + .iter() + .map(|url| Url::parse(url)) + .collect::, _>>() + .map_err(|err| proto_error(format!("{err}")))?, + }); + + Ok(Arc::new(node)) } fn try_encode( @@ -57,12 +83,27 @@ impl PhysicalExtensionCodec for ArrowFlightReadExecProtoCodec { node.name() ))); }; + let Some(stage_context) = &node.stage_context else { + return Err(proto_error( + "Upon serializing the ArrowFlightReadExec, the stage context must be set.", + )); + }; ArrowFlightReadExecProto { partitioning: Some(serialize_partitioning( &node.properties().partitioning, &DefaultPhysicalExtensionCodec {}, )?), + 
stage_context: Some(StageContextProto {
+                id: stage_context.id.to_string(),
+                n_tasks: stage_context.n_tasks as u64,
+                input_id: stage_context.input_id.to_string(),
+                input_urls: stage_context
+                    .input_urls
+                    .iter()
+                    .map(|url| url.to_string())
+                    .collect(),
+            }),
         }
         .encode(buf)
         .map_err(|err| proto_error(format!("{err}")))
@@ -74,6 +115,20 @@
 /// to send them over the wire.
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ArrowFlightReadExecProto {
-    #[prost(message, optional, tag = "2")]
+    #[prost(message, optional, tag = "1")]
     partitioning: Option<protobuf::Partitioning>,
+    #[prost(message, optional, tag = "2")]
+    stage_context: Option<StageContextProto>,
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct StageContextProto {
+    #[prost(string, tag = "1")]
+    pub id: String,
+    #[prost(uint64, tag = "2")]
+    pub n_tasks: u64,
+    #[prost(string, tag = "3")]
+    pub input_id: String,
+    #[prost(string, repeated, tag = "4")]
+    pub input_urls: Vec<String>,
+}
diff --git a/src/plan/assign_stages.rs b/src/plan/assign_stages.rs
new file mode 100644
index 00000000..48228879
--- /dev/null
+++ b/src/plan/assign_stages.rs
@@ -0,0 +1,73 @@
+use crate::context::StageContext;
+use crate::{ArrowFlightReadExec, ChannelManager};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::error::DataFusionError;
+use datafusion::physical_plan::ExecutionPlan;
+use std::cell::RefCell;
+use std::sync::Arc;
+use uuid::Uuid;
+
+pub fn assign_stages(
+    plan: Arc<dyn ExecutionPlan>,
+    channel_manager: impl TryInto<ChannelManager, Error = DataFusionError>,
+) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+    let stack = RefCell::new(vec![]);
+    let mut i = 0;
+
+    let urls = channel_manager.try_into()?.get_urls()?;
+
+    Ok(plan
+        .transform_down_up(
+            |plan| {
+                let Some(node) = plan.as_any().downcast_ref::<ArrowFlightReadExec>() else {
+                    return Ok(Transformed::no(plan));
+                };
+
+                // If the current ArrowFlightReadExec already has a task assigned, do nothing.
+                if let Some(ref stage_context) = node.stage_context {
+                    stack.borrow_mut().push(stage_context.clone());
+                    return Ok(Transformed::no(plan.clone()));
+                }
+
+                let mut input_urls = vec![];
+                for _ in 0..node.properties().output_partitioning().partition_count() {
+                    // Just round-robin the workers for assigning tasks.
+                    input_urls.push(urls[i % urls.len()].clone());
+                    i += 1;
+                }
+
+                let stage_context = if let Some(prev_stage) = stack.borrow().last() {
+                    StageContext {
+                        id: prev_stage.input_id,
+                        n_tasks: prev_stage.input_urls.len(),
+                        input_id: Uuid::new_v4(),
+                        input_urls,
+                    }
+                } else {
+                    // This is the first ArrowFlightReadExec encountered in the plan, that's
+                    // why there is no stage yet.
+                    //
+                    // As this task will not need to fan out data to upper stages, it does not
+                    // care about output tasks.
+                    StageContext {
+                        id: Uuid::new_v4(),
+                        n_tasks: 1,
+                        input_id: Uuid::new_v4(),
+                        input_urls,
+                    }
+                };
+
+                stack.borrow_mut().push(stage_context.clone());
+                let mut node = node.clone();
+                node.stage_context = Some(stage_context);
+                Ok(Transformed::yes(Arc::new(node)))
+            },
+            |plan| {
+                if plan.name() == "ArrowFlightReadExec" {
+                    stack.borrow_mut().pop();
+                }
+                Ok(Transformed::no(plan))
+            },
+        )?
+ .data) +} diff --git a/src/plan/mod.rs b/src/plan/mod.rs index 134672b2..25ba9d41 100644 --- a/src/plan/mod.rs +++ b/src/plan/mod.rs @@ -1,5 +1,7 @@ mod arrow_flight_read; mod arrow_flight_read_proto; +mod assign_stages; pub use arrow_flight_read::ArrowFlightReadExec; pub use arrow_flight_read_proto::ArrowFlightReadExecProtoCodec; +pub use assign_stages::assign_stages; diff --git a/src/stage_delegation/context.rs b/src/stage_delegation/context.rs deleted file mode 100644 index d3f69bbe..00000000 --- a/src/stage_delegation/context.rs +++ /dev/null @@ -1,25 +0,0 @@ -use datafusion_proto::protobuf; - -/// Contains the necessary context for actors in a stage to perform a distributed query. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct StageContext { - /// Unique identifier of the stage - #[prost(string, tag = "1")] - pub id: String, - #[prost(uint64, tag = "2")] - pub delegate: u64, - #[prost(string, repeated, tag = "3")] - pub actors: Vec, - #[prost(message, optional, tag = "4")] - pub partitioning: Option, - #[prost(uint64, tag = "5")] - pub prev_actors: u64, -} - -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ActorContext { - #[prost(uint64, tag = "1")] - pub caller_actor_idx: u64, - #[prost(uint64, tag = "2")] - pub actor_idx: u64, -} diff --git a/src/stage_delegation/delegation.rs b/src/stage_delegation/delegation.rs deleted file mode 100644 index 5d4f9e18..00000000 --- a/src/stage_delegation/delegation.rs +++ /dev/null @@ -1,290 +0,0 @@ -use super::StageContext; -use dashmap::{DashMap, Entry}; -use datafusion::common::{exec_datafusion_err, exec_err}; -use datafusion::error::DataFusionError; -use std::time::Duration; -use tokio::sync::oneshot; - -/// In each stage of the distributed plan, there will be N workers. All these workers -/// need to coordinate to pull data from the next stage, which will contain M workers. -/// -/// The way this is done is that for each stage, 1 worker is elected as "delegate", and -/// the rest of the workers are mere actors that wait for the delegate to tell them -/// where to go. -/// -/// Each actor in a stage knows the url of the rest of the actors, so the delegate actor can -/// go one by one telling them what does the next stage look like. That way, all the actors -/// will agree on where to go to pull data from even if they are hosted in different physical -/// machines. -/// -/// While starting a stage, several things can happen: -/// 1. The delegate can be very quick and choose the next stage context even before the other -/// actors have started waiting. -/// 2. The delegate can be very slow, and other actors might be waiting for the next context -/// info before the delegate even starting the choice of the next stage context. -/// -/// On 1, the `add_delegate_info` call will create an entry in the [DashMap] with a -/// [oneshot::Receiver] already populated with the [StageContext], that other actors -/// are free to pick up at their own pace. -/// -/// On 2, the `wait_for_delegate_info` call will create an entry in the [DashMap] with a -/// [oneshot::Sender], and listen on the other end of the channel [oneshot::Receiver] for -/// the delegate to put something there. 
-pub struct StageDelegation { - stage_targets: DashMap<(String, usize), Oneof>, - wait_timeout: Duration, -} - -impl Default for StageDelegation { - fn default() -> Self { - Self { - stage_targets: DashMap::default(), - wait_timeout: Duration::from_secs(5), - } - } -} - -impl StageDelegation { - /// Puts the [StageContext] info so that an actor can pick it up with `wait_for_delegate_info`. - /// - /// - If the actor was already waiting for this info, it just puts it on the - /// existing transmitter end. - /// - If no actor was waiting for this info, build a new channel and store the receiving end - /// so that actor can pick it up when it is ready. - pub fn add_delegate_info( - &self, - stage_id: String, - actor_idx: usize, - next_stage_context: StageContext, - ) -> Result<(), DataFusionError> { - let tx = match self.stage_targets.entry((stage_id, actor_idx)) { - Entry::Occupied(entry) => match entry.get() { - Oneof::Sender(_) => match entry.remove() { - Oneof::Sender(tx) => tx, - Oneof::Receiver(_) => unreachable!(), - }, - // This call is idempotent. If there's already a Receiver end here, it means that - // add_delegate_info() for the same stage_id was already called once. - Oneof::Receiver(_) => return Ok(()), - }, - Entry::Vacant(entry) => { - let (tx, rx) = oneshot::channel(); - entry.insert(Oneof::Receiver(rx)); - tx - } - }; - - // TODO: `send` does not wait for the other end of the channel to receive the message, - // so if nobody waits for it, we might leak an entry in `stage_targets` that will never - // be cleaned up. We can either: - // 1. schedule a cleanup task that iterates the entries cleaning up old ones - // 2. find some other API that allows us to .await until the other end receives the message, - // and on a timeout, cleanup the entry anyway. - tx.send(next_stage_context) - .map_err(|_| exec_datafusion_err!("Could not send stage context info")) - } - - /// Waits for the [StageContext] info to be provided by the delegate and returns it. - /// - /// - If the delegate already put this info, consume it immediately and return it. - /// - If the delegate did not put this info yet, create a new channel for the delegate to - /// store the info, and wait for that to happen, returning the info when it's ready. - pub async fn wait_for_delegate_info( - &self, - stage_id: String, - actor_idx: usize, - ) -> Result { - let rx = match self.stage_targets.entry((stage_id.clone(), actor_idx)) { - Entry::Occupied(entry) => match entry.get() { - Oneof::Sender(_) => return exec_err!("Programming error: while waiting for delegate info the entry in the StageDelegation target map cannot be a Sender"), - Oneof::Receiver(_) => match entry.remove() { - Oneof::Sender(_) => unreachable!(), - Oneof::Receiver(rx) => rx - }, - }, - Entry::Vacant(entry) => { - let (tx, rx) = oneshot::channel(); - entry.insert(Oneof::Sender(tx)); - rx - } - }; - - tokio::time::timeout(self.wait_timeout, rx) - .await - .map_err(|_| exec_datafusion_err!("Timeout waiting for delegate to post stage info for stage {stage_id} in actor {actor_idx}"))? 
- .map_err(|err| { - exec_datafusion_err!( - "Error waiting for delegate to tell us in which stage we are in: {err}" - ) - }) - } -} - -enum Oneof { - Sender(oneshot::Sender), - Receiver(oneshot::Receiver), -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::stage_delegation::StageContext; - use std::sync::Arc; - use uuid::Uuid; - - fn create_test_stage_context() -> StageContext { - StageContext { - id: Uuid::new_v4().to_string(), - delegate: 0, - prev_actors: 0, - actors: vec![ - "http://localhost:8080".to_string(), - "http://localhost:8081".to_string(), - ], - partitioning: Default::default(), - } - } - - #[tokio::test] - async fn test_delegate_first_then_actor_waits() { - let delegation = StageDelegation::default(); - let stage_id = Uuid::new_v4().to_string(); - let stage_context = create_test_stage_context(); - - // Delegate adds info first - delegation - .add_delegate_info(stage_id.clone(), 0, stage_context.clone()) - .unwrap(); - - // Actor waits for info (should get it immediately) - let received_context = delegation - .wait_for_delegate_info(stage_id, 0) - .await - .unwrap(); - assert_eq!(stage_context, received_context); - - // The stage target was cleaned up. - assert_eq!(delegation.stage_targets.len(), 0); - } - - #[tokio::test] - async fn test_actor_waits_first_then_delegate_adds() { - let delegation = Arc::new(StageDelegation::default()); - let stage_id = Uuid::new_v4().to_string(); - let stage_context = create_test_stage_context(); - - // Spawn a task that waits for delegate info - let delegation_clone = Arc::clone(&delegation); - let id = stage_id.clone(); - let wait_task = - tokio::spawn(async move { delegation_clone.wait_for_delegate_info(id, 0).await }); - - // Give the wait task a moment to start - tokio::time::sleep(Duration::from_millis(10)).await; - - // Delegate adds info - delegation - .add_delegate_info(stage_id, 0, stage_context.clone()) - .unwrap(); - - // Wait task should complete with the stage context - let received_context = wait_task.await.unwrap().unwrap(); - assert_eq!(stage_context, received_context); - - // The stage target was cleaned up. 
- assert_eq!(delegation.stage_targets.len(), 0); - } - - #[tokio::test] - async fn test_multiple_actors_waiting_for_same_stage() { - let delegation = Arc::new(StageDelegation::default()); - let stage_id = Uuid::new_v4().to_string(); - let stage_context = create_test_stage_context(); - - // First actor waits - let delegation_clone1 = Arc::clone(&delegation); - let id = stage_id.clone(); - let wait_task1 = - tokio::spawn(async move { delegation_clone1.wait_for_delegate_info(id, 0).await }); - - // Give the first wait task a moment to start - tokio::time::sleep(Duration::from_millis(10)).await; - - // Second actor tries to wait for the same stage - this should fail gracefully - // since there can only be one waiting receiver per stage - let result = delegation.wait_for_delegate_info(stage_id.clone(), 0).await; - assert!(result.is_err()); - - // Delegate adds info - the first actor should receive it - delegation - .add_delegate_info(stage_id, 0, stage_context.clone()) - .unwrap(); - - let received_context = wait_task1.await.unwrap().unwrap(); - assert_eq!(received_context.id, stage_context.id); - } - - #[tokio::test] - async fn test_different_stages_concurrent() { - let delegation = Arc::new(StageDelegation::default()); - let stage_id1 = Uuid::new_v4().to_string(); - let stage_id2 = Uuid::new_v4().to_string(); - let stage_context1 = create_test_stage_context(); - let stage_context2 = create_test_stage_context(); - - // Both actors wait for different stages - let delegation_clone1 = Arc::clone(&delegation); - let delegation_clone2 = Arc::clone(&delegation); - let id1 = stage_id1.clone(); - let id2 = stage_id2.clone(); - let wait_task1 = - tokio::spawn(async move { delegation_clone1.wait_for_delegate_info(id1, 0).await }); - let wait_task2 = - tokio::spawn(async move { delegation_clone2.wait_for_delegate_info(id2, 0).await }); - - // Give wait tasks a moment to start - tokio::time::sleep(Duration::from_millis(10)).await; - - // Delegates add info for both stages - delegation - .add_delegate_info(stage_id1, 0, stage_context1.clone()) - .unwrap(); - delegation - .add_delegate_info(stage_id2, 0, stage_context2.clone()) - .unwrap(); - - // Both should receive their respective contexts - let received_context1 = wait_task1.await.unwrap().unwrap(); - let received_context2 = wait_task2.await.unwrap().unwrap(); - - assert_eq!(received_context1.id, stage_context1.id.to_string()); - assert_eq!(received_context2.id, stage_context2.id.to_string()); - - // The stage target was cleaned up. 
- assert_eq!(delegation.stage_targets.len(), 0); - } - - #[tokio::test] - async fn test_add_delegate_info_twice_same_stage() { - let delegation = StageDelegation::default(); - let stage_id = Uuid::new_v4().to_string(); - let stage_context = create_test_stage_context(); - - // First add should succeed - delegation - .add_delegate_info(stage_id.clone(), 0, stage_context.clone()) - .unwrap(); - - // Second add for same stage should succeed (idempotent) - delegation - .add_delegate_info(stage_id.clone(), 0, stage_context.clone()) - .unwrap(); - - // Receiving should still work even if `add_delegate_info` was called two times - let received_context = delegation - .wait_for_delegate_info(stage_id, 0) - .await - .unwrap(); - assert_eq!(received_context, stage_context); - } -} diff --git a/src/stage_delegation/mod.rs b/src/stage_delegation/mod.rs deleted file mode 100644 index de08df60..00000000 --- a/src/stage_delegation/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod context; -mod delegation; - -pub use context::{ActorContext, StageContext}; -pub use delegation::StageDelegation; diff --git a/tests/common/insta.rs b/tests/common/insta.rs index 9df8aa35..61226841 100644 --- a/tests/common/insta.rs +++ b/tests/common/insta.rs @@ -1,5 +1,5 @@ +use datafusion::common::utils::get_available_parallelism; use std::env; -use std::thread::available_parallelism; #[macro_export] macro_rules! assert_snapshot { @@ -16,10 +16,14 @@ pub fn settings() -> insta::Settings { let cwd = env::current_dir().unwrap(); let cwd = cwd.to_str().unwrap(); settings.add_filter(cwd.trim_start_matches("/"), ""); - let cpus = available_parallelism().unwrap(); + let cpus = get_available_parallelism(); settings.add_filter(&format!(", {cpus}\\)"), ", CPUs)"); settings.add_filter(&format!("\\({cpus}\\)"), "(CPUs)"); - settings.add_filter(&format!("={cpus}"), "=CPUs"); + settings.add_filter(&format!("input_partitions={cpus}"), "input_partitions=CPUs"); + settings.add_filter( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + "UUID", + ); settings } diff --git a/tests/common/localhost.rs b/tests/common/localhost.rs index e0c16edf..5ccc9b4c 100644 --- a/tests/common/localhost.rs +++ b/tests/common/localhost.rs @@ -5,8 +5,7 @@ use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::SessionContext; use datafusion_distributed::{ - ArrowFlightChannel, ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, - SessionBuilder, + ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, SessionBuilder, }; use std::error::Error; use std::sync::atomic::AtomicUsize; @@ -78,29 +77,17 @@ impl LocalHostChannelResolver { #[async_trait] impl ChannelResolver for LocalHostChannelResolver { - async fn get_n_channels(&self, n: usize) -> Result, DataFusionError> { - let mut result = vec![]; - for _ in 0..n { - let i = self.i.fetch_add(1, std::sync::atomic::Ordering::SeqCst) % self.ports.len(); - let port = self.ports[i]; - let url = format!("http://localhost:{port}"); - let endpoint = Channel::from_shared(url.clone()).map_err(external_err)?; - let channel = endpoint.connect().await.map_err(external_err)?; - result.push(ArrowFlightChannel { - url: Url::parse(&url).map_err(external_err)?, - channel: BoxCloneSyncChannel::new(channel), - }) - } - Ok(result) + fn get_urls(&self) -> Result, DataFusionError> { + self.ports + .iter() + .map(|port| format!("http://localhost:{port}")) + .map(|url| Url::parse(&url).map_err(external_err)) + .collect::, _>>() } - - 
async fn get_channel_for_url(&self, url: &Url) -> Result { + async fn get_channel_for_url(&self, url: &Url) -> Result { let endpoint = Channel::from_shared(url.to_string()).map_err(external_err)?; let channel = endpoint.connect().await.map_err(external_err)?; - Ok(ArrowFlightChannel { - url: url.clone(), - channel: BoxCloneSyncChannel::new(channel), - }) + Ok(BoxCloneSyncChannel::new(channel)) } } diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index 9016cfbe..2b1dbfa6 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -3,6 +3,7 @@ mod common; #[cfg(test)] mod tests { + use crate::assert_snapshot; use crate::common::localhost::start_localhost_context; use datafusion::arrow::array::Int64Array; use datafusion::arrow::compute::SortOptions; @@ -26,11 +27,10 @@ mod tests { use datafusion::physical_plan::{ displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; - use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder}; + use datafusion_distributed::{assign_stages, ArrowFlightReadExec, SessionBuilder}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::{stream, TryStreamExt}; - use insta::assert_snapshot; use prost::Message; use std::any::Any; use std::fmt::Formatter; @@ -66,13 +66,14 @@ mod tests { "); let distributed_plan = build_plan(true)?; + let distributed_plan = assign_stages(distributed_plan, &ctx)?; assert_snapshot!(displayable(distributed_plan.as_ref()).indent(true).to_string(), @r" SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=10 - ArrowFlightReadExec: input_actors=10 + ArrowFlightReadExec: input_tasks=10 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/] SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] - ArrowFlightReadExec: input_actors=1 hash=[numbers@0] + ArrowFlightReadExec: input_tasks=1 hash_expr=[numbers@0] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50051/] FilterExec: numbers@0 > 1 Int64ListExec: length=6 "); @@ -80,7 +81,7 @@ mod tests { let stream = execute_stream(single_node_plan, ctx.task_ctx())?; let batches_single_node = stream.try_collect::>().await?; - assert_snapshot!(pretty_format_batches(&batches_single_node)?, @r" + assert_snapshot!(pretty_format_batches(&batches_single_node).unwrap(), @r" +---------+ | numbers | +---------+ @@ -95,7 +96,7 @@ mod tests { let stream = execute_stream(distributed_plan, ctx.task_ctx())?; let batches_distributed = stream.try_collect::>().await?; - assert_snapshot!(pretty_format_batches(&batches_distributed)?, @r" + assert_snapshot!(pretty_format_batches(&batches_distributed).unwrap(), @r" +---------+ | numbers | +---------+ diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs index c92db407..397157be 100644 --- a/tests/distributed_aggregation.rs +++ b/tests/distributed_aggregation.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; #[cfg(test)] @@ -8,6 +9,7 @@ mod tests { use crate::common::plan::distribute_aggregate; use datafusion::arrow::util::pretty::pretty_format_batches; use datafusion::physical_plan::{displayable, execute_stream}; + 
use datafusion_distributed::assign_stages; use futures::TryStreamExt; use std::error::Error; @@ -24,6 +26,8 @@ mod tests { let physical_str = displayable(physical.as_ref()).indent(true).to_string(); let physical_distributed = distribute_aggregate(physical.clone())?; + let physical_distributed = assign_stages(physical_distributed, &ctx)?; + let physical_distributed_str = displayable(physical_distributed.as_ref()) .indent(true) .to_string(); @@ -50,12 +54,12 @@ mod tests { SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] - ArrowFlightReadExec: input_actors=8 + ArrowFlightReadExec: input_tasks=8 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/] CoalesceBatchesExec: target_batch_size=8192 RepartitionExec: partitioning=Hash([RainToday@0], CPUs), input_partitions=CPUs RepartitionExec: partitioning=RoundRobinBatch(CPUs), input_partitions=1 AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] - ArrowFlightReadExec: input_actors=1 hash=[RainToday@0] + ArrowFlightReadExec: input_tasks=1 hash_expr=[RainToday@0] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50052/] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet ", ); diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index f9ae75f8..81e41ee6 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -15,7 +15,7 @@ mod tests { use datafusion::physical_plan::{ execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; - use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder}; + use datafusion_distributed::{assign_stages, ArrowFlightReadExec, SessionBuilder}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::{stream, TryStreamExt}; @@ -54,7 +54,7 @@ mod tests { Partitioning::RoundRobinBatch(size), )); } - + let plan = assign_stages(plan, &ctx)?; let stream = execute_stream(plan, ctx.task_ctx())?; let Err(err) = stream.try_collect::>().await else { diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index ce96a5f9..c04ee72d 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod common; #[cfg(test)] @@ -7,7 +8,7 @@ mod tests { use crate::common::parquet::register_parquet_tables; use datafusion::physical_expr::Partitioning; use datafusion::physical_plan::{displayable, execute_stream}; - use datafusion_distributed::ArrowFlightReadExec; + use datafusion_distributed::{assign_stages, ArrowFlightReadExec}; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -34,6 +35,7 @@ mod tests { Partitioning::RoundRobinBatch(size), )); } + let physical_distributed = assign_stages(physical_distributed, &ctx)?; let physical_distributed_str = displayable(physical_distributed.as_ref()) .indent(true) .to_string(); @@ -44,9 +46,9 @@ mod tests { assert_snapshot!(physical_distributed_str, @r" - ArrowFlightReadExec: input_actors=5 - ArrowFlightReadExec: input_actors=10 - 
ArrowFlightReadExec: input_actors=1 + ArrowFlightReadExec: input_tasks=5 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50053/, http://localhost:50054/, http://localhost:50055/] + ArrowFlightReadExec: input_tasks=10 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50056/, http://localhost:50057/, http://localhost:50058/, http://localhost:50059/, http://localhost:50050/, http://localhost:50051/, http://localhost:50053/, http://localhost:50054/, http://localhost:50055/, http://localhost:50056/] + ArrowFlightReadExec: input_tasks=1 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50057/] DataSourceExec: file_groups={1 group: [[/testdata/flights-1m.parquet]]}, projection=[FL_DATE, DEP_DELAY, ARR_DELAY, AIR_TIME, DISTANCE, DEP_TIME, ARR_TIME], file_type=parquet ", );
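
A rough sketch of how the new planning-time flow is meant to be used, based on the APIs this
patch introduces (`ChannelResolver::get_urls`, `assign_stages`) and mirroring what the localhost
tests wire up. The resolver type, worker URLs, query, and the assumption that `ChannelManager::new`
is callable from user code are all illustrative, not part of the patch:

    use std::sync::Arc;

    use async_trait::async_trait;
    use datafusion::error::DataFusionError;
    use datafusion::physical_plan::execute_stream;
    use datafusion::prelude::{SessionConfig, SessionContext};
    use datafusion_distributed::{assign_stages, BoxCloneSyncChannel, ChannelManager, ChannelResolver};
    use url::Url;

    // Hypothetical resolver over a fixed set of worker URLs.
    struct StaticResolver(Vec<Url>);

    #[async_trait]
    impl ChannelResolver for StaticResolver {
        fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
            Ok(self.0.clone())
        }

        async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
            let channel = tonic::transport::Channel::from_shared(url.to_string())
                .map_err(|e| DataFusionError::External(Box::new(e)))?
                .connect()
                .await
                .map_err(|e| DataFusionError::External(Box::new(e)))?;
            Ok(BoxCloneSyncChannel::new(channel))
        }
    }

    async fn run(workers: Vec<Url>) -> Result<(), DataFusionError> {
        // The ChannelManager travels as a SessionConfig extension so that both
        // `assign_stages` and `ArrowFlightReadExec::execute` can find it.
        let mut config = SessionConfig::new();
        config.set_extension(Arc::new(ChannelManager::new(StaticResolver(workers))));
        let ctx = SessionContext::new_with_config(config);

        let df = ctx.sql("SELECT 1").await?;
        let plan = df.create_physical_plan().await?;

        // Planning-time stage assignment: round-robins the resolved worker URLs onto every
        // ArrowFlightReadExec that has no StageContext yet. Must run before execution.
        let plan = assign_stages(plan, &ctx)?;
        let _stream = execute_stream(plan, ctx.task_ctx())?;
        Ok(())
    }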