From 3b45cc2c6c69c0b36beae2dbfe87a8e3be7db0ed Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Mon, 18 Aug 2025 16:03:54 +0200 Subject: [PATCH 01/15] Add `ConfigExtensionExt`, allowing the propagation of arbitrary [ConfigExtension]s across network boundaries --- src/config_extension_ext.rs | 456 +++++++++++++++++++++++++++++++ src/flight_service/do_get.rs | 19 +- src/lib.rs | 2 + src/plan/arrow_flight_read.rs | 26 +- tests/custom_config_extension.rs | 180 ++++++++++++ 5 files changed, 677 insertions(+), 6 deletions(-) create mode 100644 src/config_extension_ext.rs create mode 100644 tests/custom_config_extension.rs diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs new file mode 100644 index 0000000..3daf7f0 --- /dev/null +++ b/src/config_extension_ext.rs @@ -0,0 +1,456 @@ +use datafusion::common::{internal_datafusion_err, DataFusionError}; +use datafusion::config::ConfigExtension; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use delegate::delegate; +use http::{HeaderMap, HeaderName}; +use std::error::Error; +use std::str::FromStr; +use std::sync::Arc; + +const FLIGHT_METADATA_PREFIX: &str = "x-datafusion-distributed-"; + +/// Extension trait for `SessionConfig` to add support for propagating [ConfigExtension]s across +/// network calls. +pub trait ConfigExtensionExt { + /// Adds the provided [ConfigExtension] to the distributed context. The [ConfigExtension] will + /// be serialized using gRPC metadata and sent across tasks. Users are expected to call this + /// method with their own extensions to be able to access them in any place in the + /// plan. + /// + /// This method also adds the provided [ConfigExtension] to the current session option + /// extensions, the same as calling [SessionConfig::with_option_extension]. + /// + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::SessionState; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut config = SessionConfig::new(); + /// let mut opt = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// config.add_distributed_option_extension(opt).unwrap(); + /// + /// struct MyCustomSessionBuilder; + /// + /// #[async_trait] + /// impl SessionBuilder for MyCustomSessionBuilder { + /// async fn session_state(&self, mut state: SessionState) -> Result { + /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will + /// // know how to deserialize the CustomExtension from the gRPC metadata. 
+ /// state.propagate_distributed_option_extension::()?; + /// Ok(state) + /// } + /// } + /// ``` + fn add_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError>; + + /// Gets the specified [ConfigExtension] from the distributed context and adds it to + /// the [SessionConfig::options] extensions. The function will build a new [ConfigExtension] + /// out of the Arrow Flight gRPC metadata present in the [SessionConfig] and will propagate it + /// to the extension options. + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::SessionState; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut config = SessionConfig::new(); + /// let mut opt = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// config.add_distributed_option_extension(opt).unwrap(); + /// + /// struct MyCustomSessionBuilder; + /// + /// #[async_trait] + /// impl SessionBuilder for MyCustomSessionBuilder { + /// async fn session_state(&self, mut state: SessionState) -> Result { + /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will + /// // know how to deserialize the CustomExtension from the gRPC metadata. 
+ /// state.propagate_distributed_option_extension::()?; + /// Ok(state) + /// } + /// } + /// ``` + fn propagate_distributed_option_extension( + &mut self, + ) -> Result<(), DataFusionError>; +} + +impl ConfigExtensionExt for SessionConfig { + fn add_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError> { + fn parse_err(err: impl Error) -> DataFusionError { + DataFusionError::Internal(format!("Failed to add config extension: {err}")) + } + let mut meta = HeaderMap::new(); + + for entry in t.entries() { + if let Some(value) = entry.value { + meta.insert( + HeaderName::from_str(&format!( + "{}{}.{}", + FLIGHT_METADATA_PREFIX, + T::PREFIX, + entry.key + )) + .map_err(parse_err)?, + value.parse().map_err(parse_err)?, + ); + } + } + let flight_metadata = ContextGrpcMetadata(meta); + match self.get_extension::() { + None => self.set_extension(Arc::new(flight_metadata)), + Some(prev) => { + let prev = prev.as_ref().clone(); + self.set_extension(Arc::new(prev.merge(flight_metadata))) + } + } + self.options_mut().extensions.insert(t); + Ok(()) + } + + fn propagate_distributed_option_extension( + &mut self, + ) -> Result<(), DataFusionError> { + let Some(flight_metadata) = self.get_extension::() else { + return Ok(()); + }; + + let mut result = T::default(); + let mut found_some = false; + for (k, v) in flight_metadata.0.iter() { + let key = k.as_str().trim_start_matches(FLIGHT_METADATA_PREFIX); + if key.starts_with(T::PREFIX) { + found_some = true; + result.set( + &key.trim_start_matches(T::PREFIX).trim_start_matches("."), + v.to_str().map_err(|err| { + internal_datafusion_err!("Cannot parse header value: {err}") + })?, + )?; + } + } + if !found_some { + return Ok(()); + } + self.options_mut().extensions.insert(result); + Ok(()) + } +} + +impl ConfigExtensionExt for SessionStateBuilder { + delegate! { + to self.config().get_or_insert_default() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +impl ConfigExtensionExt for SessionState { + delegate! { + to self.config_mut() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +impl ConfigExtensionExt for SessionContext { + delegate! 
{ + to self.state_ref().write().config_mut() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct ContextGrpcMetadata(pub HeaderMap); + +impl ContextGrpcMetadata { + pub(crate) fn from_headers(metadata: HeaderMap) -> Self { + let mut new = HeaderMap::new(); + for (k, v) in metadata.into_iter() { + let Some(k) = k else { continue }; + if k.as_str().starts_with(FLIGHT_METADATA_PREFIX) { + new.insert(k, v); + } + } + Self(new) + } + + fn merge(mut self, other: Self) -> Self { + for (k, v) in other.0.into_iter() { + let Some(k) = k else { continue }; + self.0.insert(k, v); + } + self + } +} + +#[cfg(test)] +mod tests { + use crate::config_extension_ext::ContextGrpcMetadata; + use crate::ConfigExtensionExt; + use datafusion::common::extensions_options; + use datafusion::config::ConfigExtension; + use datafusion::prelude::SessionConfig; + use http::{HeaderMap, HeaderName, HeaderValue}; + use std::str::FromStr; + + #[test] + fn test_propagation() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut opt = CustomExtension::default(); + opt.foo = "foo".to_string(); + opt.bar = 1; + opt.baz = true; + + config.add_distributed_option_extension(opt)?; + + let mut new_config = SessionConfig::new(); + new_config.set_extension(config.get_extension::().unwrap()); + new_config.propagate_distributed_option_extension::()?; + + let opt = get_ext::(&config); + let new_opt = get_ext::(&new_config); + + assert_eq!(new_opt.foo, opt.foo); + assert_eq!(new_opt.bar, opt.bar); + assert_eq!(new_opt.baz, opt.baz); + + Ok(()) + } + + #[test] + fn test_add_extension_with_empty_values() -> Result<(), Box> { + let mut config = SessionConfig::new(); + let opt = CustomExtension::default(); + + config.add_distributed_option_extension(opt)?; + + let flight_metadata = config.get_extension::(); + assert!(flight_metadata.is_some()); + + let metadata = &flight_metadata.unwrap().0; + assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-custom.baz")); + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-custom.bar"), "0"); + assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + + Ok(()) + } + + #[test] + fn test_new_extension_overwrites_previous() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut opt1 = CustomExtension::default(); + opt1.foo = "first".to_string(); + config.add_distributed_option_extension(opt1)?; + + let mut opt2 = CustomExtension::default(); + opt2.bar = 42; + config.add_distributed_option_extension(opt2)?; + + let flight_metadata = config.get_extension::().unwrap(); + let metadata = &flight_metadata.0; + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-custom.bar"), "42"); + assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + + Ok(()) + } + + #[test] + fn test_propagate_no_metadata() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + config.propagate_distributed_option_extension::()?; + + let extension = config.options().extensions.get::(); + 
assert!(extension.is_none()); + + Ok(()) + } + + #[test] + fn test_propagate_no_matching_prefix() -> Result<(), Box> { + let mut config = SessionConfig::new(); + let mut header_map = HeaderMap::new(); + header_map.insert( + HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(), + HeaderValue::from_str("value").unwrap(), + ); + + let flight_metadata = ContextGrpcMetadata::from_headers(header_map); + config.set_extension(std::sync::Arc::new(flight_metadata)); + config.propagate_distributed_option_extension::()?; + + let extension = config.options().extensions.get::(); + assert!(extension.is_none()); + + Ok(()) + } + + #[test] + fn test_multiple_extensions_different_prefixes() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut custom_opt = CustomExtension::default(); + custom_opt.foo = "custom_value".to_string(); + custom_opt.bar = 123; + + let mut another_opt = AnotherExtension::default(); + another_opt.setting1 = "other".to_string(); + another_opt.setting2 = 456; + + config.add_distributed_option_extension(custom_opt)?; + config.add_distributed_option_extension(another_opt)?; + + let flight_metadata = config.get_extension::().unwrap(); + let metadata = &flight_metadata.0; + + assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-another.setting1")); + assert!(metadata.contains_key("x-datafusion-distributed-another.setting2")); + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + + assert_eq!(get("x-datafusion-distributed-custom.foo"), "custom_value"); + assert_eq!(get("x-datafusion-distributed-custom.bar"), "123"); + assert_eq!(get("x-datafusion-distributed-another.setting1"), "other"); + assert_eq!(get("x-datafusion-distributed-another.setting2"), "456"); + + let mut new_config = SessionConfig::new(); + new_config.set_extension(flight_metadata); + new_config.propagate_distributed_option_extension::()?; + new_config.propagate_distributed_option_extension::()?; + + let propagated_custom = get_ext::(&new_config); + let propagated_another = get_ext::(&new_config); + + assert_eq!(propagated_custom.foo, "custom_value"); + assert_eq!(propagated_custom.bar, 123); + assert_eq!(propagated_another.setting1, "other"); + assert_eq!(propagated_another.setting2, 456); + + Ok(()) + } + + #[test] + fn test_invalid_header_name() { + let mut config = SessionConfig::new(); + let extension = InvalidExtension::default(); + + let result = config.add_distributed_option_extension(extension); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_header_value() { + let mut config = SessionConfig::new(); + let extension = InvalidValueExtension::default(); + + let result = config.add_distributed_option_extension(extension); + assert!(result.is_err()); + } + + extensions_options! { + pub struct CustomExtension { + pub foo: String, default = "".to_string() + pub bar: usize, default = 0 + pub baz: bool, default = false + } + } + + impl ConfigExtension for CustomExtension { + const PREFIX: &'static str = "custom"; + } + + extensions_options! { + pub struct AnotherExtension { + pub setting1: String, default = "default1".to_string() + pub setting2: usize, default = 42 + } + } + + impl ConfigExtension for AnotherExtension { + const PREFIX: &'static str = "another"; + } + + extensions_options! 
{ + pub struct InvalidExtension { + pub key_with_spaces: String, default = "value".to_string() + } + } + + impl ConfigExtension for InvalidExtension { + const PREFIX: &'static str = "invalid key with spaces"; + } + + extensions_options! { + pub struct InvalidValueExtension { + pub key: String, default = "\u{0000}invalid\u{0001}".to_string() + } + } + + impl ConfigExtension for InvalidValueExtension { + const PREFIX: &'static str = "invalid_value"; + } + + fn get_ext(cfg: &SessionConfig) -> &T { + cfg.options().extensions.get::().unwrap() + } +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 38b7d25..2c14bba 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,4 +1,6 @@ +use super::service::StageKey; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; use crate::plan::{DistributedCodec, PartitionGroup}; @@ -10,13 +12,13 @@ use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::optimizer::OptimizerConfig; +use datafusion::prelude::SessionConfig; use futures::TryStreamExt; use prost::Message; use std::sync::Arc; +use tonic::metadata::MetadataMap; use tonic::{Request, Response, Status}; -use super::service::StageKey; - #[derive(Clone, PartialEq, ::prost::Message)] pub struct DoGet { /// The ExecutionStage that we are going to execute @@ -41,14 +43,15 @@ impl ArrowFlightEndpoint { &self, request: Request, ) -> Result::DoGetStream>, Status> { - let Ticket { ticket } = request.into_inner(); + let (metadata, _ext, ticket) = request.into_parts(); + let Ticket { ticket } = ticket; let doget = DoGet::decode(ticket).map_err(|err| { Status::invalid_argument(format!("Cannot decode DoGet message: {err}")) })?; let partition = doget.partition as usize; let task_number = doget.task_number as usize; - let (mut state, stage) = self.get_state_and_stage(doget).await?; + let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?; // find out which partition group we are executing let task = stage @@ -87,6 +90,7 @@ impl ArrowFlightEndpoint { async fn get_state_and_stage( &self, doget: DoGet, + metadata: MetadataMap, ) -> Result<(SessionState, Arc), Status> { let key = doget .stage_key @@ -102,9 +106,16 @@ impl ArrowFlightEndpoint { .stage_proto .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; + let mut config = SessionConfig::default(); + config.set_extension(Arc::new(ContextGrpcMetadata::from_headers( + metadata.into_headers(), + ))); + let state_builder = SessionStateBuilder::new() .with_runtime_env(Arc::clone(&self.runtime)) + .with_config(config) .with_default_features(); + let state_builder = self .session_builder .session_state_builder(state_builder) diff --git a/src/lib.rs b/src/lib.rs index 826bdf7..40515a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ mod channel_manager; mod common; mod composed_extension_codec; +mod config_extension_ext; mod errors; mod flight_service; mod physical_optimizer; @@ -15,6 +16,7 @@ mod user_provided_codec; pub mod test_utils; pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; +pub use config_extension_ext::ConfigExtensionExt; pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder}; pub use 
physical_optimizer::DistributedPhysicalOptimizerRule; pub use plan::ArrowFlightReadExec; diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index 62c5e76..bf3cb0b 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,6 +1,7 @@ use super::combined::CombinedRecordBatchStream; use crate::channel_manager::ChannelManager; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::tonic_status_to_datafusion_error; use crate::flight_service::{DoGet, StageKey}; use crate::plan::DistributedCodec; @@ -19,10 +20,13 @@ use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use futures::{future, TryFutureExt, TryStreamExt}; +use http::Extensions; use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; +use tonic::metadata::MetadataMap; +use tonic::Request; use url::Url; /// This node has two variants. @@ -173,6 +177,10 @@ impl ExecutionPlan for ArrowFlightReadExec { this.stage_num ))?; + let flight_metadata = context + .session_config() + .get_extension::(); + let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec {}); if let Some(ref user_codec) = get_user_codec(context.session_config()) { @@ -195,6 +203,7 @@ impl ExecutionPlan for ArrowFlightReadExec { let channel_manager_capture = channel_manager.clone(); let schema = schema.clone(); let query_id = query_id.clone(); + let flight_metadata = flight_metadata.clone().unwrap_or_default().as_ref().clone(); let key = StageKey { query_id, stage_id: child_stage_num, @@ -218,8 +227,14 @@ impl ExecutionPlan for ArrowFlightReadExec { ticket: ticket_bytes, }; - stream_from_stage_task(ticket, &url, schema.clone(), &channel_manager_capture) - .await + stream_from_stage_task( + ticket, + flight_metadata, + &url, + schema.clone(), + &channel_manager_capture, + ) + .await } }); @@ -240,12 +255,19 @@ impl ExecutionPlan for ArrowFlightReadExec { async fn stream_from_stage_task( ticket: Ticket, + metadata: ContextGrpcMetadata, url: &Url, schema: SchemaRef, channel_manager: &ChannelManager, ) -> Result { let channel = channel_manager.get_channel_for_url(url).await?; + let ticket = Request::from_parts( + MetadataMap::from_headers(metadata.0), + Extensions::default(), + ticket, + ); + let mut client = FlightServiceClient::new(channel); let stream = client .do_get(ticket) diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs new file mode 100644 index 0000000..33f3fb9 --- /dev/null +++ b/tests/custom_config_extension.rs @@ -0,0 +1,180 @@ +#[cfg(all(feature = "integration", test))] +mod tests { + use async_trait::async_trait; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::{extensions_options, internal_err}; + use datafusion::config::ConfigExtension; + use datafusion::error::DataFusionError; + use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionState, TaskContext, + }; + use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; + use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::physical_plan::{ + execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, 
PlanProperties, + }; + use datafusion_distributed::test_utils::localhost::start_localhost_context; + use datafusion_distributed::{add_user_codec, ConfigExtensionExt}; + use datafusion_distributed::{ + ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + }; + use datafusion_proto::physical_plan::PhysicalExtensionCodec; + use futures::TryStreamExt; + use prost::Message; + use std::any::Any; + use std::fmt::Formatter; + use std::sync::Arc; + + #[tokio::test] + async fn custom_config_extension() -> Result<(), Box> { + #[derive(Clone)] + struct CustomSessionBuilder; + + #[async_trait] + impl SessionBuilder for CustomSessionBuilder { + async fn session_state( + &self, + mut state: SessionState, + ) -> Result { + state.propagate_distributed_option_extension::()?; + add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); + Ok(state) + } + } + + let (mut ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + add_user_codec(&mut ctx, CustomConfigExtensionRequiredExecCodec); + ctx.add_distributed_option_extension(CustomExtension { + foo: "foo".to_string(), + bar: 1, + baz: true, + })?; + + let mut plan: Arc = Arc::new(CustomConfigExtensionRequiredExec::new()); + + for size in [1, 2, 3] { + plan = Arc::new(ArrowFlightReadExec::new_pending( + plan, + Partitioning::RoundRobinBatch(size), + )); + } + + let plan = DistributedPhysicalOptimizerRule::default().distribute_plan(plan)?; + let stream = execute_stream(Arc::new(plan), ctx.task_ctx())?; + // It should not fail. + stream.try_collect::>().await?; + + Ok(()) + } + + extensions_options! { + pub struct CustomExtension { + pub foo: String, default = "".to_string() + pub bar: usize, default = 0 + pub baz: bool, default = false + } + } + + impl ConfigExtension for CustomExtension { + const PREFIX: &'static str = "custom"; + } + + #[derive(Debug)] + pub struct CustomConfigExtensionRequiredExec { + plan_properties: PlanProperties, + } + + impl CustomConfigExtensionRequiredExec { + fn new() -> Self { + let schema = Schema::new(vec![Field::new("numbers", DataType::Int64, false)]); + Self { + plan_properties: PlanProperties::new( + EquivalenceProperties::new(Arc::new(schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ), + } + } + } + + impl DisplayAs for CustomConfigExtensionRequiredExec { + fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CustomConfigExtensionRequiredExec") + } + } + + impl ExecutionPlan for CustomConfigExtensionRequiredExec { + fn name(&self) -> &str { + "CustomConfigExtensionRequiredExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + _: usize, + ctx: Arc, + ) -> datafusion::common::Result { + if ctx + .session_config() + .options() + .extensions + .get::() + .is_none() + { + return internal_err!("CustomExtension not found in context"); + } + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::empty(), + ))) + } + } + + #[derive(Debug)] + struct CustomConfigExtensionRequiredExecCodec; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct CustomConfigExtensionRequiredExecProto {} + + impl PhysicalExtensionCodec for CustomConfigExtensionRequiredExecCodec { + fn try_decode( + &self, + _buf: &[u8], + _: 
&[Arc], + _registry: &dyn FunctionRegistry, + ) -> datafusion::common::Result> { + Ok(Arc::new(CustomConfigExtensionRequiredExec::new())) + } + + fn try_encode( + &self, + _node: Arc, + buf: &mut Vec, + ) -> datafusion::common::Result<()> { + CustomConfigExtensionRequiredExecProto::default() + .encode(buf) + .unwrap(); + Ok(()) + } + } +} From ae8b6c97d9ef312804034545014356628b006cce Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 16:53:05 +0200 Subject: [PATCH 02/15] Rename propagate_distributed_option_extension to retrieve_distributed_option_extension --- src/config_extension_ext.rs | 24 ++++++++++++------------ tests/custom_config_extension.rs | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 3daf7f0..449a28a 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -56,7 +56,7 @@ pub trait ConfigExtensionExt { /// async fn session_state(&self, mut state: SessionState) -> Result { /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.propagate_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::()?; /// Ok(state) /// } /// } @@ -105,12 +105,12 @@ pub trait ConfigExtensionExt { /// async fn session_state(&self, mut state: SessionState) -> Result { /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.propagate_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::()?; /// Ok(state) /// } /// } /// ``` - fn propagate_distributed_option_extension( + fn retrieve_distributed_option_extension( &mut self, ) -> Result<(), DataFusionError>; } @@ -151,7 +151,7 @@ impl ConfigExtensionExt for SessionConfig { Ok(()) } - fn propagate_distributed_option_extension( + fn retrieve_distributed_option_extension( &mut self, ) -> Result<(), DataFusionError> { let Some(flight_metadata) = self.get_extension::() else { @@ -184,7 +184,7 @@ impl ConfigExtensionExt for SessionStateBuilder { delegate! { to self.config().get_or_insert_default() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; } } } @@ -193,7 +193,7 @@ impl ConfigExtensionExt for SessionState { delegate! { to self.config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; } } } @@ -202,7 +202,7 @@ impl ConfigExtensionExt for SessionContext { delegate! 
{ to self.state_ref().write().config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; } } } @@ -254,7 +254,7 @@ mod tests { let mut new_config = SessionConfig::new(); new_config.set_extension(config.get_extension::().unwrap()); - new_config.propagate_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::()?; let opt = get_ext::(&config); let new_opt = get_ext::(&new_config); @@ -316,7 +316,7 @@ mod tests { fn test_propagate_no_metadata() -> Result<(), Box> { let mut config = SessionConfig::new(); - config.propagate_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::()?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -335,7 +335,7 @@ mod tests { let flight_metadata = ContextGrpcMetadata::from_headers(header_map); config.set_extension(std::sync::Arc::new(flight_metadata)); - config.propagate_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::()?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -375,8 +375,8 @@ mod tests { let mut new_config = SessionConfig::new(); new_config.set_extension(flight_metadata); - new_config.propagate_distributed_option_extension::()?; - new_config.propagate_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::()?; let propagated_custom = get_ext::(&new_config); let propagated_another = get_ext::(&new_config); diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index 33f3fb9..d671d90 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -37,7 +37,7 @@ mod tests { &self, mut state: SessionState, ) -> Result { - state.propagate_distributed_option_extension::()?; + state.retrieve_distributed_option_extension::()?; add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); Ok(state) } From 9ac83b0b5b4d45464fcb6cfe8fe0e66650a5d6ef Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 16:55:52 +0200 Subject: [PATCH 03/15] Change x-datafusion-distributed- to x-datafusion-distributed-config- --- src/config_extension_ext.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 449a28a..6476ce8 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -8,7 +8,7 @@ use std::error::Error; use std::str::FromStr; use std::sync::Arc; -const FLIGHT_METADATA_PREFIX: &str = "x-datafusion-distributed-"; +const FLIGHT_METADATA_CONFIG_PREFIX: &str = "x-datafusion-distributed-config-"; /// Extension trait for `SessionConfig` to add support for propagating [ConfigExtension]s across /// network calls. 
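After this rename, a propagated option key is namespaced twice: once by the library-wide config prefix and once by the extension's own `ConfigExtension::PREFIX`. A minimal sketch of the resulting key layout, using the `http` crate this module already depends on (`custom`/`foo` are the example names used throughout this series):

```rust
use http::HeaderName;
use std::str::FromStr;

const FLIGHT_METADATA_CONFIG_PREFIX: &str = "x-datafusion-distributed-config-";

// e.g. ("custom", "foo") -> "x-datafusion-distributed-config-custom.foo"
fn metadata_key(
    extension_prefix: &str,
    entry_key: &str,
) -> Result<HeaderName, http::header::InvalidHeaderName> {
    HeaderName::from_str(&format!(
        "{FLIGHT_METADATA_CONFIG_PREFIX}{extension_prefix}.{entry_key}"
    ))
}
```

On the receiving side, `retrieve_distributed_option_extension` strips `FLIGHT_METADATA_CONFIG_PREFIX` and then (as of the next patch in this series) matches on the full `format!("{}.", T::PREFIX)`, so two extensions with distinct prefixes never collide in the same header map.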
@@ -130,7 +130,7 @@ impl ConfigExtensionExt for SessionConfig { meta.insert( HeaderName::from_str(&format!( "{}{}.{}", - FLIGHT_METADATA_PREFIX, + FLIGHT_METADATA_CONFIG_PREFIX, T::PREFIX, entry.key )) @@ -161,7 +161,7 @@ impl ConfigExtensionExt for SessionConfig { let mut result = T::default(); let mut found_some = false; for (k, v) in flight_metadata.0.iter() { - let key = k.as_str().trim_start_matches(FLIGHT_METADATA_PREFIX); + let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); if key.starts_with(T::PREFIX) { found_some = true; result.set( @@ -215,7 +215,7 @@ impl ContextGrpcMetadata { let mut new = HeaderMap::new(); for (k, v) in metadata.into_iter() { let Some(k) = k else { continue }; - if k.as_str().starts_with(FLIGHT_METADATA_PREFIX) { + if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) { new.insert(k, v); } } From e15796e3e97dadec31eaf148f07fd7f965ffb26b Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 16:59:35 +0200 Subject: [PATCH 04/15] Check for full format!("{}.", T::PREFIX) in gRPC keys --- src/config_extension_ext.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 6476ce8..5a26e15 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -162,10 +162,11 @@ impl ConfigExtensionExt for SessionConfig { let mut found_some = false; for (k, v) in flight_metadata.0.iter() { let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); - if key.starts_with(T::PREFIX) { + let prefix = format!("{}.", T::PREFIX); + if key.starts_with(&prefix) { found_some = true; result.set( - &key.trim_start_matches(T::PREFIX).trim_start_matches("."), + &key.trim_start_matches(&prefix), v.to_str().map_err(|err| { internal_datafusion_err!("Cannot parse header value: {err}") })?, From 104e3fa46da5b1af78396d9fd14a7ef49fcf7377 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 17:04:24 +0200 Subject: [PATCH 05/15] Remove double clone --- src/plan/arrow_flight_read.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index bf3cb0b..36d7786 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -203,7 +203,10 @@ impl ExecutionPlan for ArrowFlightReadExec { let channel_manager_capture = channel_manager.clone(); let schema = schema.clone(); let query_id = query_id.clone(); - let flight_metadata = flight_metadata.clone().unwrap_or_default().as_ref().clone(); + let flight_metadata = flight_metadata + .as_ref() + .map(|v| v.as_ref().clone()) + .unwrap_or_default(); let key = StageKey { query_id, stage_id: child_stage_num, From 8576c2c64fc10f881d75a0c2e9d9972b1352aae8 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 17:23:36 +0200 Subject: [PATCH 06/15] Fix tests --- src/config_extension_ext.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 5a26e15..22beb83 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -283,9 +283,9 @@ mod tests { assert!(metadata.contains_key("x-datafusion-distributed-custom.baz")); let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); - assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); - assert_eq!(get("x-datafusion-distributed-custom.bar"), "0"); - assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + 
assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "0"); + assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false"); Ok(()) } @@ -306,9 +306,9 @@ mod tests { let metadata = &flight_metadata.0; let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); - assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); - assert_eq!(get("x-datafusion-distributed-custom.bar"), "42"); - assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "42"); + assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false"); Ok(()) } From d35ddff726f5f2f27cee451334ba93d9f1fbe761 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 17:24:28 +0200 Subject: [PATCH 07/15] Fix tests --- src/config_extension_ext.rs | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 22beb83..41f5141 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -278,9 +278,9 @@ mod tests { assert!(flight_metadata.is_some()); let metadata = &flight_metadata.unwrap().0; - assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); - assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); - assert!(metadata.contains_key("x-datafusion-distributed-custom.baz")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.baz")); let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); @@ -362,17 +362,26 @@ mod tests { let flight_metadata = config.get_extension::().unwrap(); let metadata = &flight_metadata.0; - assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); - assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); - assert!(metadata.contains_key("x-datafusion-distributed-another.setting1")); - assert!(metadata.contains_key("x-datafusion-distributed-another.setting2")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-config-another.setting1")); + assert!(metadata.contains_key("x-datafusion-distributed-config-another.setting2")); let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); - assert_eq!(get("x-datafusion-distributed-custom.foo"), "custom_value"); - assert_eq!(get("x-datafusion-distributed-custom.bar"), "123"); - assert_eq!(get("x-datafusion-distributed-another.setting1"), "other"); - assert_eq!(get("x-datafusion-distributed-another.setting2"), "456"); + assert_eq!( + get("x-datafusion-distributed-config-custom.foo"), + "custom_value" + ); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "123"); + assert_eq!( + get("x-datafusion-distributed-config-another.setting1"), + "other" + ); + assert_eq!( + get("x-datafusion-distributed-config-another.setting2"), + "456" + ); let mut new_config = SessionConfig::new(); new_config.set_extension(flight_metadata); From 210ea86f5061683bcc31704bad8709bd37ec28ef Mon Sep 17 00:00:00 2001 From: Rob Tandy 
Date: Wed, 20 Aug 2025 13:10:49 -0400 Subject: [PATCH 08/15] allow collect left hashjoins --- benchmarks/gen-tpch.sh | 17 +- benchmarks/src/tpch/run.rs | 26 +-- src/common/ttl_map.rs | 2 +- src/common/util.rs | 38 ++++ src/physical_optimizer.rs | 27 ++- src/stage/execution_stage.rs | 1 + src/test_utils/tpch.rs | 19 +- tests/non_distributed_consistency_test.rs | 248 ---------------------- 8 files changed, 90 insertions(+), 288 deletions(-) delete mode 100644 tests/non_distributed_consistency_test.rs diff --git a/benchmarks/gen-tpch.sh b/benchmarks/gen-tpch.sh index 98ec9c8..0e9de99 100755 --- a/benchmarks/gen-tpch.sh +++ b/benchmarks/gen-tpch.sh @@ -2,14 +2,14 @@ set -e -SCALE_FACTOR=1 +SCALE_FACTOR=10 # https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} -if [ -z "$SCALE_FACTOR" ] ; then +if [ -z "$SCALE_FACTOR" ]; then echo "Internal error: Scale factor not specified" exit 1 fi @@ -36,7 +36,7 @@ if test -f "${FILE}"; then else echo " Copying answers to ${TPCH_DIR}/answers" mkdir -p "${TPCH_DIR}/answers" - docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" + docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" fi # Create 'parquet' files from tbl @@ -45,9 +45,9 @@ if test -d "${FILE}"; then echo " parquet files exist ($FILE exists)." else echo " creating parquet files using benchmark binary ..." - pushd "${SCRIPT_DIR}" > /dev/null + pushd "${SCRIPT_DIR}" >/dev/null $CARGO_COMMAND -- tpch-convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet - popd > /dev/null + popd >/dev/null fi # Create 'csv' files from tbl @@ -56,8 +56,7 @@ if test -d "${FILE}"; then echo " csv files exist ($FILE exists)." else echo " creating csv files using benchmark binary ..." - pushd "${SCRIPT_DIR}" > /dev/null + pushd "${SCRIPT_DIR}" >/dev/null $CARGO_COMMAND -- tpch-convert --input "${TPCH_DIR}" --output "${TPCH_DIR}/csv" --format csv - popd > /dev/null + popd >/dev/null fi - diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 507e408..158fab2 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -121,19 +121,19 @@ impl SessionBuilder for RunOpt { .with_collect_statistics(!self.disable_statistics) .with_target_partitions(self.partitions()); - // FIXME: these three options are critical for the correct function of the library - // but we are not enforcing that the user sets them. They are here at the moment - // but we should figure out a way to do this better. - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold = 0; - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold_rows = 0; - - config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; + // // FIXME: these three options are critical for the correct function of the library + // // but we are not enforcing that the user sets them. They are here at the moment + // // but we should figure out a way to do this better. 
+ // config + // .options_mut() + // .optimizer + // .hash_join_single_partition_threshold = 0; + // config + // .options_mut() + // .optimizer + // .hash_join_single_partition_threshold_rows = 0; + // + // config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; // end critical options section let rt_builder = self.common.runtime_env_builder()?; diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index abfd377..011d647 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -94,7 +94,7 @@ where shard.insert(key); } BucketOp::Clear => { - let keys_to_delete = mem::replace(&mut shard, HashSet::new()); + let keys_to_delete = mem::take(&mut shard); for key in keys_to_delete { data.remove(&key); } diff --git a/src/common/util.rs b/src/common/util.rs index 085c5c2..effe69f 100644 --- a/src/common/util.rs +++ b/src/common/util.rs @@ -1,7 +1,10 @@ +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::error::Result; +use datafusion::physical_plan::joins::PartitionMode; use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; use std::fmt::Write; +use std::sync::Arc; pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result { let mut f = String::new(); @@ -34,3 +37,38 @@ pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result) -> Result { + // recursively check to see if this stages plan contains a NestedLoopJoinExec + let mut has_unsplittable_plan = false; + let search = |f: &Arc| { + if f.as_any() + .downcast_ref::() + .is_some() + { + has_unsplittable_plan = true; + return Ok(TreeNodeRecursion::Stop); + } else if let Some(hash_join) = f + .as_any() + .downcast_ref::() + { + if hash_join.partition_mode() != &PartitionMode::Partitioned { + has_unsplittable_plan = true; + return Ok(TreeNodeRecursion::Stop); + } + } + + Ok(TreeNodeRecursion::Continue) + }; + plan.apply(search)?; + Ok(!has_unsplittable_plan) +} diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 19a7296..dbc3061 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use super::stage::ExecutionStage; +use crate::common::util::can_be_divided; use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec}; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::error::DataFusionError; @@ -83,12 +84,14 @@ impl DistributedPhysicalOptimizerRule { internal_datafusion_err!("Expected RepartitionExec to have a child"), )?); - let maybe_isolated_plan = if let Some(ppt) = self.partitions_per_task { - let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt)); - plan.with_new_children(vec![isolated])? - } else { - plan - }; + let maybe_isolated_plan = + if can_be_divided(&plan)? && self.partitions_per_task.is_some() { + let ppt = self.partitions_per_task.unwrap(); + let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt)); + plan.with_new_children(vec![isolated])? 
+ } else { + plan + }; return Ok(Transformed::yes(Arc::new( ArrowFlightReadExec::new_pending( @@ -120,7 +123,7 @@ impl DistributedPhysicalOptimizerRule { ) -> Result { let mut inputs = vec![]; - let distributed = plan.transform_down(|plan| { + let distributed = plan.clone().transform_down(|plan| { let Some(node) = plan.as_any().downcast_ref::() else { return Ok(Transformed::no(plan)); }; @@ -137,9 +140,13 @@ impl DistributedPhysicalOptimizerRule { let mut stage = ExecutionStage::new(query_id, *num, distributed.data, inputs); *num += 1; - if let Some(partitions_per_task) = self.partitions_per_task { - stage = stage.with_maximum_partitions_per_task(partitions_per_task); - } + stage = match (self.partitions_per_task, can_be_divided(&plan)?) { + (Some(partitions_per_task), true) => { + stage.with_maximum_partitions_per_task(partitions_per_task) + } + (_, _) => stage, + }; + stage.depth = depth; Ok(stage) diff --git a/src/stage/execution_stage.rs b/src/stage/execution_stage.rs index d66591c..619ed67 100644 --- a/src/stage/execution_stage.rs +++ b/src/stage/execution_stage.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use datafusion::common::internal_err; +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; use datafusion::physical_plan::ExecutionPlan; diff --git a/src/test_utils/tpch.rs b/src/test_utils/tpch.rs index b10787c..e00551d 100644 --- a/src/test_utils/tpch.rs +++ b/src/test_utils/tpch.rs @@ -160,13 +160,18 @@ where macro_rules! must_generate_tpch_table { ($generator:ident, $arrow:ident, $name:literal, $data_dir:expr) => { - generate_table( - // TODO: Consider adjusting the partitions and batch sizes. - $arrow::new($generator::new(SCALE_FACTOR, 1, 1)).with_batch_size(1000), - $name, - $data_dir, - ) - .expect(concat!("Failed to generate ", $name, " table")); + let data_dir = $data_dir.join(format!("{}.parquet", $name)); + fs::create_dir_all(data_dir.clone()).expect("Failed to create data directory"); + // create three partitions for the table + (1..=3).for_each(|part| { + generate_table( + // TODO: Consider adjusting the partitions and batch sizes. 
+ $arrow::new($generator::new(SCALE_FACTOR, part, 3)).with_batch_size(1000), + &format!("{}.parquet", part), + &data_dir.clone().into_boxed_path(), + ) + .expect(concat!("Failed to generate ", $name, " table")); + }); }; } diff --git a/tests/non_distributed_consistency_test.rs b/tests/non_distributed_consistency_test.rs deleted file mode 100644 index ea2f395..0000000 --- a/tests/non_distributed_consistency_test.rs +++ /dev/null @@ -1,248 +0,0 @@ -mod common; - -#[cfg(all(feature = "integration", test))] -mod tests { - use crate::common::{ensure_tpch_data, get_test_data_dir, get_test_tpch_query}; - use async_trait::async_trait; - use datafusion::error::DataFusionError; - use datafusion::execution::SessionStateBuilder; - use datafusion::prelude::{SessionConfig, SessionContext}; - use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; - use futures::TryStreamExt; - use std::error::Error; - use std::sync::Arc; - - #[tokio::test] - async fn test_tpch_1() -> Result<(), Box> { - test_tpch_query(1).await - } - - #[tokio::test] - async fn test_tpch_2() -> Result<(), Box> { - test_tpch_query(2).await - } - - #[tokio::test] - async fn test_tpch_3() -> Result<(), Box> { - test_tpch_query(3).await - } - - #[tokio::test] - async fn test_tpch_4() -> Result<(), Box> { - test_tpch_query(4).await - } - - #[tokio::test] - async fn test_tpch_5() -> Result<(), Box> { - test_tpch_query(5).await - } - - #[tokio::test] - async fn test_tpch_6() -> Result<(), Box> { - test_tpch_query(6).await - } - - #[tokio::test] - async fn test_tpch_7() -> Result<(), Box> { - test_tpch_query(7).await - } - - #[tokio::test] - async fn test_tpch_8() -> Result<(), Box> { - test_tpch_query(8).await - } - - #[tokio::test] - async fn test_tpch_9() -> Result<(), Box> { - test_tpch_query(9).await - } - - #[tokio::test] - async fn test_tpch_10() -> Result<(), Box> { - test_tpch_query(10).await - } - - #[tokio::test] - async fn test_tpch_11() -> Result<(), Box> { - test_tpch_query(11).await - } - - #[tokio::test] - async fn test_tpch_12() -> Result<(), Box> { - test_tpch_query(12).await - } - - #[tokio::test] - async fn test_tpch_13() -> Result<(), Box> { - test_tpch_query(13).await - } - - #[tokio::test] - async fn test_tpch_14() -> Result<(), Box> { - test_tpch_query(14).await - } - - #[tokio::test] - async fn test_tpch_15() -> Result<(), Box> { - test_tpch_query(15).await - } - - #[tokio::test] - async fn test_tpch_16() -> Result<(), Box> { - test_tpch_query(16).await - } - - #[tokio::test] - async fn test_tpch_17() -> Result<(), Box> { - test_tpch_query(17).await - } - - #[tokio::test] - async fn test_tpch_18() -> Result<(), Box> { - test_tpch_query(18).await - } - - #[tokio::test] - async fn test_tpch_19() -> Result<(), Box> { - test_tpch_query(19).await - } - - #[tokio::test] - async fn test_tpch_20() -> Result<(), Box> { - test_tpch_query(20).await - } - - #[tokio::test] - async fn test_tpch_21() -> Result<(), Box> { - test_tpch_query(21).await - } - - #[tokio::test] - // TODO: Add support for NestedLoopJoinExec to support query 22. 
- #[ignore] - async fn test_tpch_22() -> Result<(), Box> { - test_tpch_query(22).await - } - - async fn test_tpch_query(query_id: u8) -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(2, TestSessionBuilder).await; - run_tpch_query(ctx, query_id).await - } - - #[derive(Clone)] - struct TestSessionBuilder; - - #[async_trait] - impl SessionBuilder for TestSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - let mut config = SessionConfig::new().with_target_partitions(3); - - // FIXME: these three options are critical for the correct function of the library - // but we are not enforcing that the user sets them. They are here at the moment - // but we should figure out a way to do this better. - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold = 0; - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold_rows = 0; - - config.options_mut().optimizer.prefer_hash_join = true; - // end critical options section - - let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); - Ok(builder - .with_config(config) - .with_physical_optimizer_rule(Arc::new(rule))) - } - - async fn session_context( - &self, - ctx: SessionContext, - ) -> std::result::Result { - Ok(ctx) - } - } - - // test_non_distributed_consistency runs each TPC-H query twice - once in a distributed manner - // and once in a non-distributed manner. For each query, it asserts that the results are identical. - async fn run_tpch_query(ctx2: SessionContext, query_id: u8) -> Result<(), Box> { - ensure_tpch_data().await; - - let sql = get_test_tpch_query(query_id); - - // Context 1: Non-distributed execution. - let config1 = SessionConfig::new().with_target_partitions(3); - let state1 = SessionStateBuilder::new() - .with_default_features() - .with_config(config1) - .build(); - let ctx1 = SessionContext::new_with_state(state1); - - // Register tables for first context - for table_name in [ - "lineitem", "orders", "part", "partsupp", "customer", "nation", "region", "supplier", - ] { - let query_path = get_test_data_dir().join(format!("{}.parquet", table_name)); - ctx1.register_parquet( - table_name, - query_path.to_string_lossy().as_ref(), - datafusion::prelude::ParquetReadOptions::default(), - ) - .await?; - - ctx2.register_parquet( - table_name, - query_path.to_string_lossy().as_ref(), - datafusion::prelude::ParquetReadOptions::default(), - ) - .await?; - } - - let (stream1, stream2) = if query_id == 15 { - let queries: Vec<&str> = sql - .split(';') - .map(str::trim) - .filter(|s| !s.is_empty()) - .collect(); - - println!("queryies: {:?}", queries); - - ctx1.sql(queries[0]).await?.collect().await?; - ctx2.sql(queries[0]).await?.collect().await?; - let df1 = ctx1.sql(queries[1]).await?; - let df2 = ctx2.sql(queries[1]).await?; - let stream1 = df1.execute_stream().await?; - let stream2 = df2.execute_stream().await?; - - ctx1.sql(queries[2]).await?.collect().await?; - ctx2.sql(queries[2]).await?.collect().await?; - (stream1, stream2) - } else { - let stream1 = ctx1.sql(&sql).await?.execute_stream().await?; - let stream2 = ctx2.sql(&sql).await?.execute_stream().await?; - (stream1, stream2) - }; - - let batches1 = stream1.try_collect::>().await?; - let batches2 = stream2.try_collect::>().await?; - - let formatted1 = arrow::util::pretty::pretty_format_batches(&batches1)?; - let formatted2 = arrow::util::pretty::pretty_format_batches(&batches2)?; - - assert_eq!( - formatted1.to_string(), - 
formatted2.to_string(), - "Query {} results differ between executions", - query_id - ); - - Ok(()) - } -} From ff9167d6164804b914b23e5f3c4165a73367f58e Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Wed, 20 Aug 2025 13:22:51 -0400 Subject: [PATCH 09/15] allow nested_loop_joins --- src/common/util.rs | 29 +++++++++++++++++++ src/test_utils/tpch.rs | 19 +++++++----- ...stency_test.rs => tpch_validation_test.rs} | 16 +++++----- 3 files changed, 50 insertions(+), 14 deletions(-) rename tests/{non_distributed_consistency_test.rs => tpch_validation_test.rs} (95%) diff --git a/src/common/util.rs b/src/common/util.rs index 085c5c2..124af41 100644 --- a/src/common/util.rs +++ b/src/common/util.rs @@ -1,7 +1,10 @@ +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::error::Result; +use datafusion::physical_plan::joins::PartitionMode; use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; use std::fmt::Write; +use std::sync::Arc; pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result { let mut f = String::new(); @@ -34,3 +37,29 @@ pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result) -> Result { + // recursively check to see if this stages plan contains a NestedLoopJoinExec + let mut has_unsplittable_plan = false; + let search = |f: &Arc| { + if f.as_any() + .downcast_ref::() + .is_some() + { + has_unsplittable_plan = true; + return Ok(TreeNodeRecursion::Stop); + } + + Ok(TreeNodeRecursion::Continue) + }; + plan.apply(search)?; + Ok(!has_unsplittable_plan) +} diff --git a/src/test_utils/tpch.rs b/src/test_utils/tpch.rs index b10787c..e00551d 100644 --- a/src/test_utils/tpch.rs +++ b/src/test_utils/tpch.rs @@ -160,13 +160,18 @@ where macro_rules! must_generate_tpch_table { ($generator:ident, $arrow:ident, $name:literal, $data_dir:expr) => { - generate_table( - // TODO: Consider adjusting the partitions and batch sizes. - $arrow::new($generator::new(SCALE_FACTOR, 1, 1)).with_batch_size(1000), - $name, - $data_dir, - ) - .expect(concat!("Failed to generate ", $name, " table")); + let data_dir = $data_dir.join(format!("{}.parquet", $name)); + fs::create_dir_all(data_dir.clone()).expect("Failed to create data directory"); + // create three partitions for the table + (1..=3).for_each(|part| { + generate_table( + // TODO: Consider adjusting the partitions and batch sizes. 
+ $arrow::new($generator::new(SCALE_FACTOR, part, 3)).with_batch_size(1000), + &format!("{}.parquet", part), + &data_dir.clone().into_boxed_path(), + ) + .expect(concat!("Failed to generate ", $name, " table")); + }); }; } diff --git a/tests/non_distributed_consistency_test.rs b/tests/tpch_validation_test.rs similarity index 95% rename from tests/non_distributed_consistency_test.rs rename to tests/tpch_validation_test.rs index ea2f395..67419c9 100644 --- a/tests/non_distributed_consistency_test.rs +++ b/tests/tpch_validation_test.rs @@ -6,9 +6,12 @@ mod tests { use async_trait::async_trait; use datafusion::error::DataFusionError; use datafusion::execution::SessionStateBuilder; + use datafusion::physical_plan::displayable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; + use datafusion_distributed::{ + display_stage_graphviz, DistributedPhysicalOptimizerRule, ExecutionStage, SessionBuilder, + }; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -119,8 +122,6 @@ mod tests { } #[tokio::test] - // TODO: Add support for NestedLoopJoinExec to support query 22. - #[ignore] async fn test_tpch_22() -> Result<(), Box> { test_tpch_query(22).await } @@ -155,7 +156,6 @@ mod tests { config.options_mut().optimizer.prefer_hash_join = true; // end critical options section - let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); Ok(builder .with_config(config) @@ -174,7 +174,6 @@ mod tests { // and once in a non-distributed manner. For each query, it asserts that the results are identical. async fn run_tpch_query(ctx2: SessionContext, query_id: u8) -> Result<(), Box> { ensure_tpch_data().await; - let sql = get_test_tpch_query(query_id); // Context 1: Non-distributed execution. 
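Condensed, the check this renamed test performs is the sketch below (not part of the patch; names are illustrative): run the same SQL on a single-node and a distributed context, collect both streams, and compare the pretty-printed batches, so row content and row order must match exactly.

```rust
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;

/// Sketch: assert that a single-node and a distributed context produce
/// identical pretty-printed results for the same query.
async fn assert_same_results(
    single: &SessionContext,
    distributed: &SessionContext,
    sql: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    let stream1 = single.sql(sql).await?.execute_stream().await?;
    let stream2 = distributed.sql(sql).await?.execute_stream().await?;
    let batches1 = stream1.try_collect::<Vec<_>>().await?;
    let batches2 = stream2.try_collect::<Vec<_>>().await?;
    let formatted1 = arrow::util::pretty::pretty_format_batches(&batches1)?;
    let formatted2 = arrow::util::pretty::pretty_format_batches(&batches2)?;
    assert_eq!(formatted1.to_string(), formatted2.to_string());
    Ok(())
}
```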
@@ -205,6 +204,9 @@ mod tests { .await?; } + // Query 15 has three queries in it, one creating the view, the second + // executing, which we want to capture the output of, and the third + // tearing down the view let (stream1, stream2) = if query_id == 15 { let queries: Vec<&str> = sql .split(';') @@ -212,12 +214,11 @@ mod tests { .filter(|s| !s.is_empty()) .collect(); - println!("queryies: {:?}", queries); - ctx1.sql(queries[0]).await?.collect().await?; ctx2.sql(queries[0]).await?.collect().await?; let df1 = ctx1.sql(queries[1]).await?; let df2 = ctx2.sql(queries[1]).await?; + let stream1 = df1.execute_stream().await?; let stream2 = df2.execute_stream().await?; @@ -227,6 +228,7 @@ mod tests { } else { let stream1 = ctx1.sql(&sql).await?.execute_stream().await?; let stream2 = ctx2.sql(&sql).await?.execute_stream().await?; + (stream1, stream2) }; From 25d24d2a01e7cb17ef3202d957590f34b0df30d0 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 21 Aug 2025 14:28:18 +0200 Subject: [PATCH 10/15] Rework SessionBuilder --- benchmarks/src/tpch/run.rs | 29 ++--- src/common/mod.rs | 1 + src/config_extension_ext.rs | 66 +++++------ src/flight_service/do_get.rs | 51 +++------ src/flight_service/mod.rs | 4 +- src/flight_service/service.rs | 10 +- src/flight_service/session_builder.rs | 130 ++++++++++------------ src/lib.rs | 5 +- src/test_utils/localhost.rs | 37 +++--- tests/custom_config_extension.rs | 33 +++--- tests/custom_extension_codec.rs | 26 ++--- tests/distributed_aggregation.rs | 4 +- tests/error_propagation.rs | 26 +++-- tests/highly_distributed_query.rs | 4 +- tests/non_distributed_consistency_test.rs | 74 ++++++------ 15 files changed, 218 insertions(+), 282 deletions(-) diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 507e408..f3cd2bb 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -36,14 +36,16 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::SessionStateBuilder; +use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; use datafusion_distributed::test_utils::localhost::start_localhost_context; -use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; +use datafusion_distributed::{ + DistributedPhysicalOptimizerRule, DistributedSessionBuilder, DistributedSessionBuilderContext, +}; use log::info; use structopt::StructOpt; @@ -110,11 +112,13 @@ pub struct RunOpt { } #[async_trait] -impl SessionBuilder for RunOpt { - fn session_state_builder( +impl DistributedSessionBuilder for RunOpt { + async fn build_session_state( &self, - mut builder: SessionStateBuilder, - ) -> Result { + _ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut builder = SessionStateBuilder::new().with_default_features(); + let mut config = self .common .config()? @@ -145,17 +149,14 @@ impl SessionBuilder for RunOpt { builder = builder.with_physical_optimizer_rule(Arc::new(rule)); } - Ok(builder + let state = builder .with_config(config) - .with_runtime_env(rt_builder.build_arc()?)) - } + .with_runtime_env(rt_builder.build_arc()?) 
+ .build(); - async fn session_context( - &self, - ctx: SessionContext, - ) -> std::result::Result<SessionContext, DataFusionError> { + let ctx = SessionContext::from(state); self.register_tables(&ctx).await?; - Ok(ctx) + Ok(ctx.state()) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 572b996..6e77e1a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,2 +1,3 @@ +#[allow(unused)] pub mod ttl_map; pub mod util; diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 41f5141..eb927e0 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -27,9 +27,9 @@ pub trait ConfigExtensionExt { /// # use async_trait::async_trait; /// # use datafusion::common::{extensions_options, DataFusionError}; /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::SessionState; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// extensions_options! { /// pub struct CustomExtension { @@ -52,11 +52,13 @@ pub trait ConfigExtensionExt { /// struct MyCustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for MyCustomSessionBuilder { - /// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> { + /// impl DistributedSessionBuilder for MyCustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> { + /// let mut state = SessionStateBuilder::new().build(); + /// /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.retrieve_distributed_option_extension::<CustomExtension>()?; + /// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers)?; /// Ok(state) /// } /// } @@ -76,9 +78,9 @@ pub trait ConfigExtensionExt { /// # use async_trait::async_trait; /// # use datafusion::common::{extensions_options, DataFusionError}; /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::SessionState; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// extensions_options! { /// pub struct CustomExtension { @@ -101,17 +103,19 @@ pub trait ConfigExtensionExt { /// struct MyCustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for MyCustomSessionBuilder { - /// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> { + /// impl DistributedSessionBuilder for MyCustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> { + /// let mut state = SessionStateBuilder::new().build(); /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata.
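To make the wire format in the example above concrete: each extension field travels as one gRPC metadata entry whose key is an internal prefix plus `<PREFIX>.<field>`. The sketch below assumes the literal prefix `x-datafusion-distributed-config-` (taken from the tests further down in this patch) and reuses the `CustomExtension` type from the doc example:

```rust
use datafusion::common::DataFusionError;
use datafusion::prelude::SessionConfig;
use datafusion_distributed::ConfigExtensionExt;
use http::{HeaderMap, HeaderName, HeaderValue};
use std::str::FromStr;

fn rebuild_from_headers() -> Result<(), DataFusionError> {
    // One header per extension field: `x-datafusion-distributed-config-custom.foo`.
    let mut headers = HeaderMap::new();
    headers.insert(
        HeaderName::from_str("x-datafusion-distributed-config-custom.foo").unwrap(),
        HeaderValue::from_str("foo").unwrap(),
    );
    let mut config = SessionConfig::new();
    // Rebuilds a CustomExtension from raw metadata and stores it in the options.
    config.retrieve_distributed_option_extension::<CustomExtension>(&headers)?;
    let ext = config.options().extensions.get::<CustomExtension>().unwrap();
    assert_eq!(ext.foo, "foo");
    Ok(())
}
```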
- /// state.retrieve_distributed_option_extension::<CustomExtension>()?; + /// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers)?; /// Ok(state) /// } /// } /// ``` fn retrieve_distributed_option_extension<T: ConfigExtension + Default>( &mut self, + headers: &HeaderMap, ) -> Result<(), DataFusionError>; } @@ -153,14 +157,11 @@ impl ConfigExtensionExt for SessionConfig { fn retrieve_distributed_option_extension<T: ConfigExtension + Default>( &mut self, + headers: &HeaderMap, ) -> Result<(), DataFusionError> { - let Some(flight_metadata) = self.get_extension::<ContextGrpcMetadata>() else { - return Ok(()); - }; - let mut result = T::default(); let mut found_some = false; - for (k, v) in flight_metadata.0.iter() { + for (k, v) in headers.iter() { let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); let prefix = format!("{}.", T::PREFIX); if key.starts_with(&prefix) { @@ -185,7 +186,7 @@ impl ConfigExtensionExt for SessionStateBuilder { delegate! { to self.config().get_or_insert_default() { fn add_distributed_option_extension<T: ConfigExtension>(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -194,7 +195,7 @@ impl ConfigExtensionExt for SessionState { delegate! { to self.config_mut() { fn add_distributed_option_extension<T: ConfigExtension>(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -203,7 +204,7 @@ impl ConfigExtensionExt for SessionContext { delegate! { to self.state_ref().write().config_mut() { fn add_distributed_option_extension<T: ConfigExtension>(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -212,17 +213,6 @@ impl ConfigExtensionExt for SessionContext { pub(crate) struct ContextGrpcMetadata(pub HeaderMap); impl ContextGrpcMetadata { - pub(crate) fn from_headers(metadata: HeaderMap) -> Self { - let mut new = HeaderMap::new(); - for (k, v) in metadata.into_iter() { - let Some(k) = k else { continue }; - if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) { - new.insert(k, v); - } - } - Self(new) - } - fn merge(mut self, other: Self) -> Self { for (k, v) in other.0.into_iter() { let Some(k) = k else { continue }; @@ -252,10 +242,9 @@ mod tests { opt.baz = true; config.add_distributed_option_extension(opt)?; - + let metadata = config.get_extension::<ContextGrpcMetadata>().unwrap(); let mut new_config = SessionConfig::new(); - new_config.set_extension(config.get_extension::<ContextGrpcMetadata>().unwrap()); - new_config.retrieve_distributed_option_extension::<CustomExtension>()?; + new_config.retrieve_distributed_option_extension::<CustomExtension>(&metadata.0)?; let opt = get_ext::<CustomExtension>(&config); let new_opt = get_ext::<CustomExtension>(&new_config); @@ -317,7 +306,7 @@ mod tests { fn test_propagate_no_metadata() -> Result<(), Box<dyn Error>> { let mut config = SessionConfig::new(); - config.retrieve_distributed_option_extension::<CustomExtension>()?; + config.retrieve_distributed_option_extension::<CustomExtension>(&Default::default())?; let extension = config.options().extensions.get::<CustomExtension>(); assert!(extension.is_none()); @@ -330,13 +319,11 @@ mod tests { let mut config = SessionConfig::new(); let mut header_map = HeaderMap::new(); header_map.insert( - HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(), +
HeaderName::from_str("x-datafusion-distributed-config-other.setting").unwrap(), HeaderValue::from_str("value").unwrap(), ); - let flight_metadata = ContextGrpcMetadata::from_headers(header_map); - config.set_extension(std::sync::Arc::new(flight_metadata)); - config.retrieve_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::(&header_map)?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -384,9 +371,8 @@ mod tests { ); let mut new_config = SessionConfig::new(); - new_config.set_extension(flight_metadata); - new_config.retrieve_distributed_option_extension::()?; - new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::(&metadata)?; + new_config.retrieve_distributed_option_extension::(&metadata)?; let propagated_custom = get_ext::(&new_config); let propagated_another = get_ext::(&new_config); diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 2c14bba..0be1c5b 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -3,6 +3,7 @@ use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; +use crate::flight_service::session_builder::DistributedSessionBuilderContext; use crate::plan::{DistributedCodec, PartitionGroup}; use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto}; use crate::user_provided_codec::get_user_codec; @@ -10,9 +11,7 @@ use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; -use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::optimizer::OptimizerConfig; -use datafusion::prelude::SessionConfig; +use datafusion::execution::SessionState; use futures::TryStreamExt; use prost::Message; use std::sync::Arc; @@ -90,7 +89,7 @@ impl ArrowFlightEndpoint { async fn get_state_and_stage( &self, doget: DoGet, - metadata: MetadataMap, + metadata_map: MetadataMap, ) -> Result<(SessionState, Arc), Status> { let key = doget .stage_key @@ -106,54 +105,34 @@ impl ArrowFlightEndpoint { .stage_proto .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; - let mut config = SessionConfig::default(); - config.set_extension(Arc::new(ContextGrpcMetadata::from_headers( - metadata.into_headers(), - ))); - - let state_builder = SessionStateBuilder::new() - .with_runtime_env(Arc::clone(&self.runtime)) - .with_config(config) - .with_default_features(); - - let state_builder = self - .session_builder - .session_state_builder(state_builder) - .map_err(|err| datafusion_error_to_tonic_status(&err))?; - - let state = state_builder.build(); + let headers = metadata_map.into_headers(); let mut state = self .session_builder - .session_state(state) + .build_session_state(DistributedSessionBuilderContext { + runtime_env: Arc::clone(&self.runtime), + headers: headers.clone(), + }) .await .map_err(|err| datafusion_error_to_tonic_status(&err))?; - let function_registry = - state.function_registry().ok_or(Status::invalid_argument( - "FunctionRegistry not present in newly built SessionState", - ))?; - let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec); if let Some(ref user_codec) = get_user_codec(state.config()) { 
combined_codec.push_arc(Arc::clone(user_codec)); } - let stage = stage_from_proto( - stage_proto, - function_registry, - self.runtime.as_ref(), - &combined_codec, - ) - .map(Arc::new) - .map_err(|err| { - Status::invalid_argument(format!("Cannot decode stage proto: {err}")) - })?; + let stage = + stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &combined_codec) + .map(Arc::new) + .map_err(|err| { + Status::invalid_argument(format!("Cannot decode stage proto: {err}")) + })?; // Add the extensions that might be required for ExecutionPlan nodes in the plan let config = state.config_mut(); config.set_extension(Arc::clone(&self.channel_manager)); config.set_extension(stage.clone()); + config.set_extension(Arc::new(ContextGrpcMetadata(headers))); Ok::<_, Status>((state, stage)) }) diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index e3b53c8..b389074 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -5,4 +5,6 @@ mod session_builder; pub(crate) use do_get::DoGet; pub use service::{ArrowFlightEndpoint, StageKey}; -pub use session_builder::{NoopSessionBuilder, SessionBuilder}; +pub use session_builder::{ + DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext, +}; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index fc19269..78e6bdb 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -1,6 +1,6 @@ use crate::channel_manager::ChannelManager; -use crate::flight_service::session_builder::NoopSessionBuilder; -use crate::flight_service::SessionBuilder; +use crate::flight_service::session_builder::DefaultSessionBuilder; +use crate::flight_service::DistributedSessionBuilder; use crate::stage::ExecutionStage; use crate::ChannelResolver; use arrow_flight::flight_service_server::FlightService; @@ -36,7 +36,7 @@ pub struct ArrowFlightEndpoint { pub(super) runtime: Arc<RuntimeEnv>, #[allow(clippy::type_complexity)] pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>, - pub(super) session_builder: Arc<dyn SessionBuilder + Send + Sync>, + pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>, } impl ArrowFlightEndpoint { @@ -45,13 +45,13 @@ impl ArrowFlightEndpoint { channel_manager: Arc::new(ChannelManager::new(channel_resolver)), runtime: Arc::new(RuntimeEnv::default()), stages: DashMap::new(), - session_builder: Arc::new(NoopSessionBuilder), + session_builder: Arc::new(DefaultSessionBuilder), } } pub fn with_session_builder( &mut self, - session_builder: impl SessionBuilder + Send + Sync + 'static, + session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, ) { self.session_builder = Arc::new(session_builder); } diff --git a/src/flight_service/session_builder.rs b/src/flight_service/session_builder.rs index 64be9dd..3decda0 100644 --- a/src/flight_service/session_builder.rs +++ b/src/flight_service/session_builder.rs @@ -1,27 +1,34 @@ use async_trait::async_trait; use datafusion::error::DataFusionError; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::prelude::SessionContext; +use http::HeaderMap; +use std::sync::Arc; + +#[derive(Debug, Clone, Default)] +pub struct DistributedSessionBuilderContext { + pub runtime_env: Arc<RuntimeEnv>, + pub headers: HeaderMap, +} /// Trait called by the Arrow Flight endpoint that handles distributed parts of a DataFusion -/// plan for building a DataFusion's [datafusion::prelude::SessionContext]. +/// plan for building a DataFusion [SessionState].
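Before the trait definition below, it is worth noting that the blanket `impl<F, Fut> DistributedSessionBuilder for F` added at the end of this file means a plain async function already satisfies the trait; the reworked tests rely on exactly this shape:

```rust
use datafusion::common::DataFusionError;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion_distributed::DistributedSessionBuilderContext;

// A free async function usable anywhere a DistributedSessionBuilder is
// expected, e.g. start_localhost_context in the test utilities.
async fn build_state(
    ctx: DistributedSessionBuilderContext,
) -> Result<SessionState, DataFusionError> {
    Ok(SessionStateBuilder::new()
        .with_runtime_env(ctx.runtime_env)
        .with_default_features()
        .build())
}
```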
#[async_trait] -pub trait SessionBuilder { - /// Takes a [SessionStateBuilder] and adds whatever is necessary for it to work, like - /// custom extension codecs, custom physical optimization rules, UDFs, UDAFs, config - /// extensions, etc... +pub trait DistributedSessionBuilder { + /// Builds a custom [SessionState] scoped to a single ArrowFlight gRPC call, allowing the + /// users to provide a customized DataFusion session with things like custom extension codecs, + /// custom physical optimization rules, UDFs, UDAFs, config extensions, etc... /// - /// Example: adding some custom extension plan codecs + /// Example: /// /// ```rust /// # use std::sync::Arc; /// # use async_trait::async_trait; /// # use datafusion::error::DataFusionError; - /// # use datafusion::execution::runtime_env::RuntimeEnv; - /// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; + /// # use datafusion::execution::{FunctionRegistry, SessionState, SessionStateBuilder}; /// # use datafusion::physical_plan::ExecutionPlan; /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec; - /// # use datafusion_distributed::{with_user_codec, SessionBuilder}; + /// # use datafusion_distributed::{with_user_codec, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// #[derive(Debug)] /// struct CustomExecCodec; @@ -40,79 +47,54 @@ pub trait SessionBuilder { /// struct CustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// fn session_state_builder(&self, mut builder: SessionStateBuilder) -> Result { - /// // Add your UDFs, optimization rules, etc... - /// Ok(with_user_codec(builder, CustomExecCodec)) - /// } - /// } - /// ``` - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(builder) - } - - /// Modifies the [SessionState] and returns it. Same as [SessionBuilder::session_state_builder] - /// but operating on an already built [SessionState]. - /// - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::DataFusionError; - /// # use datafusion::execution::SessionState; - /// # use datafusion_distributed::SessionBuilder; + /// impl DistributedSessionBuilder for CustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { + /// let builder = SessionStateBuilder::new() + /// .with_runtime_env(ctx.runtime_env.clone()) + /// .with_default_features(); /// - /// #[derive(Clone)] - /// struct CustomSessionBuilder; + /// let builder = with_user_codec(builder, CustomExecCodec); /// - /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// async fn session_state(&self, state: SessionState) -> Result { - /// // mutate the state adding any custom logic - /// Ok(state) - /// } - /// } - /// ``` - async fn session_state(&self, state: SessionState) -> Result { - Ok(state) - } - - /// Modifies the [SessionContext] and returns it. Same as [SessionBuilder::session_state_builder] - /// or [SessionBuilder::session_state] but operation on an already built [SessionContext]. - /// - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::DataFusionError; - /// # use datafusion::prelude::SessionContext; - /// # use datafusion_distributed::SessionBuilder; - /// - /// #[derive(Clone)] - /// struct CustomSessionBuilder; + /// // Add your UDFs, optimization rules, etc... 
/// - /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// async fn session_context(&self, ctx: SessionContext) -> Result<SessionContext, DataFusionError> { - /// // mutate the context adding any custom logic - /// Ok(ctx) + /// Ok(builder.build()) /// } /// } /// ``` - async fn session_context( + async fn build_session_state( &self, - ctx: SessionContext, - ) -> Result<SessionContext, DataFusionError> { - Ok(ctx) - } + ctx: DistributedSessionBuilderContext, + ) -> Result<SessionState, DataFusionError>; } -/// Noop implementation of the [SessionBuilder]. Used by default if no [SessionBuilder] is provided +/// Default implementation of the [DistributedSessionBuilder]. Used by default if no [DistributedSessionBuilder] is provided /// while building the Arrow Flight endpoint. #[derive(Debug, Clone)] -pub struct NoopSessionBuilder; +pub struct DefaultSessionBuilder; + +#[async_trait] +impl DistributedSessionBuilder for DefaultSessionBuilder { + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result<SessionState, DataFusionError> { + Ok(SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env.clone()) + .with_default_features() + .build()) + } +} -impl SessionBuilder for NoopSessionBuilder {} +#[async_trait] +impl<F, Fut> DistributedSessionBuilder for F +where + F: Fn(DistributedSessionBuilderContext) -> Fut + Send + Sync + 'static, + Fut: std::future::Future<Output = Result<SessionState, DataFusionError>> + Send + 'static, +{ + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result<SessionState, DataFusionError> { + self(ctx).await + } +} diff --git a/src/lib.rs b/src/lib.rs index 40515a9..dbf8a5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,10 @@ pub mod test_utils; pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; pub use config_extension_ext::ConfigExtensionExt; -pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder}; +pub use flight_service::{ + ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder, + DistributedSessionBuilderContext, +}; pub use physical_optimizer::DistributedPhysicalOptimizerRule; pub use plan::ArrowFlightReadExec; pub use stage::{display_stage_graphviz, ExecutionStage}; diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index 83a328e..a782807 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -1,12 +1,13 @@ use crate::{ - ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, SessionBuilder, + ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, + DistributedSessionBuilder, DistributedSessionBuilderContext, }; use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; use datafusion::common::runtime::JoinSet; use datafusion::common::DataFusionError; -use datafusion::execution::SessionStateBuilder; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::prelude::SessionContext; -use datafusion::{common::runtime::JoinSet, prelude::SessionConfig}; use std::error::Error; use std::sync::Arc; use std::time::Duration; @@ -19,7 +20,7 @@ pub async fn start_localhost_context<B>( session_builder: B, ) -> (SessionContext, JoinSet<()>) where - B: SessionBuilder + Send + Sync + 'static, + B: DistributedSessionBuilder + Send + Sync + 'static, B: Clone, { let listeners = futures::future::try_join_all( @@ -53,25 +54,19 @@ where } tokio::time::sleep(Duration::from_millis(100)).await; - let config = SessionConfig::new().with_target_partitions(3); - - let builder = SessionStateBuilder::new() - .with_default_features() - .with_config(config); - let builder =
session_builder.session_state_builder(builder).unwrap(); - - let state = builder.build(); - let state = session_builder.session_state(state).await.unwrap(); - - let ctx = SessionContext::new_with_state(state); - let ctx = session_builder.session_context(ctx).await.unwrap(); - - ctx.state_ref() - .write() + let mut state = session_builder + .build_session_state(DistributedSessionBuilderContext { + runtime_env: Arc::new(RuntimeEnv::default()), + headers: Default::default(), + }) + .await + .unwrap(); + state .config_mut() .set_extension(Arc::new(ChannelManager::new(channel_resolver))); + state.config_mut().options_mut().execution.target_partitions = 3; - (ctx, join_set) + (SessionContext::from(state), join_set) } #[derive(Clone)] @@ -108,7 +103,7 @@ impl ChannelResolver for LocalHostChannelResolver { pub async fn spawn_flight_service( channel_resolver: impl ChannelResolver + Send + Sync + 'static, - session_builder: impl SessionBuilder + Send + Sync + 'static, + session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, incoming: TcpListener, ) -> Result<(), Box> { let mut endpoint = ArrowFlightEndpoint::new(channel_resolver); diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index d671d90..af84a92 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -1,12 +1,11 @@ #[cfg(all(feature = "integration", test))] mod tests { - use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::{extensions_options, internal_err}; use datafusion::config::ConfigExtension; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionState, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -15,10 +14,10 @@ mod tests { execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{add_user_codec, ConfigExtensionExt}; use datafusion_distributed::{ - ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, ConfigExtensionExt, DistributedSessionBuilderContext, }; + use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::TryStreamExt; use prost::Message; @@ -28,23 +27,19 @@ mod tests { #[tokio::test] async fn custom_config_extension() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - - #[async_trait] - impl SessionBuilder for CustomSessionBuilder { - async fn session_state( - &self, - mut state: SessionState, - ) -> Result { - state.retrieve_distributed_option_extension::()?; - add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); - Ok(state) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + state.retrieve_distributed_option_extension::(&ctx.headers)?; + add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); + Ok(state) } - let (mut ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; - add_user_codec(&mut ctx, 
CustomConfigExtensionRequiredExecCodec); + let (mut ctx, _guard) = start_localhost_context(3, build_state).await; ctx.add_distributed_option_extension(CustomExtension { foo: "foo".to_string(), bar: 1, diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index 3d2870d..fa86082 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -7,7 +7,7 @@ mod tests { use datafusion::arrow::util::pretty::pretty_format_batches; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::logical_expr::Operator; use datafusion::physical_expr::expressions::{col, lit, BinaryExpr}; @@ -22,11 +22,11 @@ mod tests { use datafusion::physical_plan::{ displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; - use datafusion_distributed::assert_snapshot; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, assert_snapshot, DistributedSessionBuilderContext, }; + use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::{stream, TryStreamExt}; @@ -38,18 +38,18 @@ mod tests { #[tokio::test] #[ignore] async fn custom_extension_codec() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - impl SessionBuilder for CustomSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(with_user_codec(builder, Int64ListExecCodec)) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + add_user_codec(state.config_mut(), Int64ListExecCodec); + Ok(state) } - let (ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, build_state).await; let single_node_plan = build_plan(false)?; assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r" diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs index 1a572f2..273b332 100644 --- a/tests/distributed_aggregation.rs +++ b/tests/distributed_aggregation.rs @@ -5,13 +5,13 @@ mod tests { use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; use datafusion_distributed::test_utils::plan::distribute_aggregate; - use datafusion_distributed::{assert_snapshot, NoopSessionBuilder}; + use datafusion_distributed::{assert_snapshot, DefaultSessionBuilder}; use futures::TryStreamExt; use std::error::Error; #[tokio::test] async fn distributed_aggregation() -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(3, NoopSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index e8d2ca5..b030d09 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -3,7 +3,7 @@ mod tests { use 
datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -13,7 +13,8 @@ mod tests { }; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, + DistributedSessionBuilderContext, }; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; @@ -26,17 +27,18 @@ mod tests { #[tokio::test] async fn test_error_propagation() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - impl SessionBuilder for CustomSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(with_user_codec(builder, ErrorExecCodec)) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + add_user_codec(state.config_mut(), ErrorExecCodec); + Ok(state) } - let (ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + + let (ctx, _guard) = start_localhost_context(3, build_state).await; let mut plan: Arc = Arc::new(ErrorExec::new("something failed")); diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index d6c2cfa..c2b8160 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -4,7 +4,7 @@ mod tests { use datafusion::physical_plan::{displayable, execute_stream}; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; - use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, NoopSessionBuilder}; + use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, DefaultSessionBuilder}; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -12,7 +12,7 @@ mod tests { #[tokio::test] #[ignore] async fn highly_distributed_query() -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(9, NoopSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(9, DefaultSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx.sql(r#"SELECT * FROM flights_1m"#).await?; diff --git a/tests/non_distributed_consistency_test.rs b/tests/non_distributed_consistency_test.rs index ea2f395..819289b 100644 --- a/tests/non_distributed_consistency_test.rs +++ b/tests/non_distributed_consistency_test.rs @@ -3,12 +3,13 @@ mod common; #[cfg(all(feature = "integration", test))] mod tests { use crate::common::{ensure_tpch_data, get_test_data_dir, get_test_tpch_query}; - use async_trait::async_trait; use datafusion::error::DataFusionError; - use datafusion::execution::SessionStateBuilder; + use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, 
SessionBuilder}; + use datafusion_distributed::{ + DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, + }; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -126,48 +127,37 @@ mod tests { } async fn test_tpch_query(query_id: u8) -> Result<(), Box<dyn Error>> { - let (ctx, _guard) = start_localhost_context(2, TestSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(2, build_state).await; run_tpch_query(ctx, query_id).await } - #[derive(Clone)] - struct TestSessionBuilder; - - #[async_trait] - impl SessionBuilder for TestSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result<SessionStateBuilder, DataFusionError> { - let mut config = SessionConfig::new().with_target_partitions(3); - - // FIXME: these three options are critical for the correct function of the library - // but we are not enforcing that the user sets them. They are here at the moment - // but we should figure out a way to do this better. - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold = 0; - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold_rows = 0; - - config.options_mut().optimizer.prefer_hash_join = true; - // end critical options section - - let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); - Ok(builder - .with_config(config) - .with_physical_optimizer_rule(Arc::new(rule))) - } - - async fn session_context( - &self, - ctx: SessionContext, - ) -> std::result::Result<SessionContext, DataFusionError> { - Ok(ctx) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result<SessionState, DataFusionError> { + let mut config = SessionConfig::new().with_target_partitions(3); + + // FIXME: these three options are critical for the correct function of the library + // but we are not enforcing that the user sets them. They are here at the moment + // but we should figure out a way to do this better. + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold = 0; + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold_rows = 0; + + config.options_mut().optimizer.prefer_hash_join = true; + // end critical options section + + let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); + Ok(SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .with_physical_optimizer_rule(Arc::new(rule)) + .build()) } // test_non_distributed_consistency runs each TPC-H query twice - once in a distributed manner From c692a6f0fe18d1b1bc2afc825a93f0a15c51a8ce Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 21 Aug 2025 14:36:21 +0200 Subject: [PATCH 11/15] Fix clippy errors --- src/common/ttl_map.rs | 6 +++--- src/config_extension_ext.rs | 43 ++++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index abfd377..064ea98 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -94,7 +94,7 @@ where shard.insert(key); } BucketOp::Clear => { - let keys_to_delete = mem::replace(&mut shard, HashSet::new()); + let keys_to_delete = mem::take(&mut shard); for key in keys_to_delete { data.remove(&key); } @@ -253,14 +253,14 @@ where /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The /// function terminates if `shutdown` is signalled.
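A quick note on the `mem::take` change below (reverted again in patch 13): for any type implementing `Default`, the two spellings are interchangeable, so neither direction alters behavior:

```rust
use std::collections::HashSet;
use std::mem;

fn drain_shard(shard: &mut HashSet<u64>) -> HashSet<u64> {
    // `mem::take(shard)` is shorthand for `mem::replace(shard, HashSet::new())`:
    // both move the current set out and leave a fresh empty one behind.
    mem::take(shard)
}
```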
- async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &Vec<Bucket<K>>) { + async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) { loop { tokio::time::sleep(period).await; Self::gc(time.clone(), buckets); } } - fn gc(time: Arc<AtomicU64>, buckets: &Vec<Bucket<K>>) { + fn gc(time: Arc<AtomicU64>, buckets: &[Bucket<K>]) { let index = time.load(std::sync::atomic::Ordering::SeqCst) % buckets.len() as u64; buckets[index as usize].clear(); time.fetch_add(1, std::sync::atomic::Ordering::SeqCst); diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index eb927e0..bab103a 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -167,7 +167,7 @@ impl ConfigExtensionExt for SessionConfig { if key.starts_with(&prefix) { found_some = true; result.set( - &key.trim_start_matches(&prefix), + key.trim_start_matches(&prefix), v.to_str().map_err(|err| { internal_datafusion_err!("Cannot parse header value: {err}") })?, @@ -236,10 +236,11 @@ mod tests { fn test_propagation() -> Result<(), Box<dyn Error>> { let mut config = SessionConfig::new(); - let mut opt = CustomExtension::default(); - opt.foo = "foo".to_string(); - opt.bar = 1; - opt.baz = true; + let opt = CustomExtension { + foo: "foo".to_string(), + bar: 1, + baz: true, + }; config.add_distributed_option_extension(opt)?; let metadata = config.get_extension::<ContextGrpcMetadata>().unwrap(); @@ -283,12 +284,16 @@ mod tests { fn test_new_extension_overwrites_previous() -> Result<(), Box<dyn Error>> { let mut config = SessionConfig::new(); - let mut opt1 = CustomExtension::default(); - opt1.foo = "first".to_string(); + let opt1 = CustomExtension { + foo: "first".to_string(), + ..Default::default() + }; config.add_distributed_option_extension(opt1)?; - let mut opt2 = CustomExtension::default(); - opt2.bar = 42; + let opt2 = CustomExtension { + bar: 42, + ..Default::default() + }; config.add_distributed_option_extension(opt2)?; let flight_metadata = config.get_extension::<ContextGrpcMetadata>().unwrap(); @@ -335,13 +340,17 @@ mod tests { fn test_multiple_extensions_different_prefixes() -> Result<(), Box<dyn Error>> { let mut config = SessionConfig::new(); - let mut custom_opt = CustomExtension::default(); - custom_opt.foo = "custom_value".to_string(); - custom_opt.bar = 123; + let custom_opt = CustomExtension { + foo: "custom_value".to_string(), + bar: 123, + ..Default::default() + }; - let mut another_opt = AnotherExtension::default(); - another_opt.setting1 = "other".to_string(); - another_opt.setting2 = 456; + let another_opt = AnotherExtension { + setting1: "other".to_string(), + setting2: 456, + ..Default::default() + }; config.add_distributed_option_extension(custom_opt)?; config.add_distributed_option_extension(another_opt)?; @@ -371,8 +380,8 @@ mod tests { ); let mut new_config = SessionConfig::new(); - new_config.retrieve_distributed_option_extension::<CustomExtension>(&metadata)?; - new_config.retrieve_distributed_option_extension::<AnotherExtension>(&metadata)?; + new_config.retrieve_distributed_option_extension::<CustomExtension>(metadata)?; + new_config.retrieve_distributed_option_extension::<AnotherExtension>(metadata)?; let propagated_custom = get_ext::<CustomExtension>(&new_config); let propagated_another = get_ext::<AnotherExtension>(&new_config); From 10dc81a2dc05b2b18563fdd7cd618f969536af50 Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Thu, 21 Aug 2025 08:51:28 -0400 Subject: [PATCH 12/15] check if plan can be divided into tasks --- src/common/ttl_map.rs | 3 +-- src/common/util.rs | 1 - src/physical_optimizer.rs | 27 +++++++++++++++--------- tests/tpch_validation_test.rs | 6 ++---- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index
abfd377..ff66c73 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -27,7 +27,6 @@ use dashmap::{DashMap, Entry}; use datafusion::error::DataFusionError; use std::collections::HashSet; use std::hash::Hash; -use std::mem; use std::sync::atomic::AtomicU64; #[cfg(test)] use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; @@ -94,7 +93,7 @@ where shard.insert(key); } BucketOp::Clear => { - let keys_to_delete = mem::replace(&mut shard, HashSet::new()); + let keys_to_delete = std::mem::replace(&mut shard, HashSet::new()); for key in keys_to_delete { data.remove(&key); } diff --git a/src/common/util.rs b/src/common/util.rs index 124af41..52b0652 100644 --- a/src/common/util.rs +++ b/src/common/util.rs @@ -1,6 +1,5 @@ use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::error::Result; -use datafusion::physical_plan::joins::PartitionMode; use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; use std::fmt::Write; diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 19a7296..dbc3061 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use super::stage::ExecutionStage; +use crate::common::util::can_be_divided; use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec}; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::error::DataFusionError; @@ -83,12 +84,14 @@ impl DistributedPhysicalOptimizerRule { internal_datafusion_err!("Expected RepartitionExec to have a child"), )?); - let maybe_isolated_plan = if let Some(ppt) = self.partitions_per_task { - let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt)); - plan.with_new_children(vec![isolated])? - } else { - plan - }; + let maybe_isolated_plan = + if can_be_divided(&plan)? && self.partitions_per_task.is_some() { + let ppt = self.partitions_per_task.unwrap(); + let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt)); + plan.with_new_children(vec![isolated])? + } else { + plan + }; return Ok(Transformed::yes(Arc::new( ArrowFlightReadExec::new_pending( @@ -120,7 +123,7 @@ impl DistributedPhysicalOptimizerRule { ) -> Result { let mut inputs = vec![]; - let distributed = plan.transform_down(|plan| { + let distributed = plan.clone().transform_down(|plan| { let Some(node) = plan.as_any().downcast_ref::() else { return Ok(Transformed::no(plan)); }; @@ -137,9 +140,13 @@ impl DistributedPhysicalOptimizerRule { let mut stage = ExecutionStage::new(query_id, *num, distributed.data, inputs); *num += 1; - if let Some(partitions_per_task) = self.partitions_per_task { - stage = stage.with_maximum_partitions_per_task(partitions_per_task); - } + stage = match (self.partitions_per_task, can_be_divided(&plan)?) 
{ + (Some(partitions_per_task), true) => { + stage.with_maximum_partitions_per_task(partitions_per_task) + } + (_, _) => stage, + }; + stage.depth = depth; Ok(stage) diff --git a/tests/tpch_validation_test.rs b/tests/tpch_validation_test.rs index 67419c9..c54dc41 100644 --- a/tests/tpch_validation_test.rs +++ b/tests/tpch_validation_test.rs @@ -6,12 +6,10 @@ mod tests { use async_trait::async_trait; use datafusion::error::DataFusionError; use datafusion::execution::SessionStateBuilder; - use datafusion::physical_plan::displayable; + use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{ - display_stage_graphviz, DistributedPhysicalOptimizerRule, ExecutionStage, SessionBuilder, - }; + use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; From e3bbcc452e4da58573d91a2bcfd844d1d6ebed70 Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Thu, 21 Aug 2025 09:22:03 -0400 Subject: [PATCH 13/15] revert change --- src/common/ttl_map.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index 011d647..abfd377 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -94,7 +94,7 @@ where shard.insert(key); } BucketOp::Clear => { - let keys_to_delete = mem::take(&mut shard); + let keys_to_delete = mem::replace(&mut shard, HashSet::new()); for key in keys_to_delete { data.remove(&key); } From 9aa6e165f107703e9db5e626defbefa0be70370c Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Thu, 21 Aug 2025 10:33:57 -0400 Subject: [PATCH 14/15] fmt and clippy --- src/common/util.rs | 3 --- src/stage/execution_stage.rs | 1 - 2 files changed, 4 deletions(-) diff --git a/src/common/util.rs b/src/common/util.rs index 35682b7..085c5c2 100644 --- a/src/common/util.rs +++ b/src/common/util.rs @@ -1,10 +1,7 @@ -use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::error::Result; -use datafusion::physical_plan::joins::PartitionMode; use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; use std::fmt::Write; -use std::sync::Arc; pub fn display_plan_with_partition_in_out(plan: &dyn ExecutionPlan) -> Result { let mut f = String::new(); diff --git a/src/stage/execution_stage.rs b/src/stage/execution_stage.rs index 619ed67..d66591c 100644 --- a/src/stage/execution_stage.rs +++ b/src/stage/execution_stage.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use datafusion::common::internal_err; -use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; use datafusion::physical_plan::ExecutionPlan; From a726bff68162a3f451c47d3376b1444bff43dde7 Mon Sep 17 00:00:00 2001 From: Rob Tandy Date: Thu, 21 Aug 2025 10:39:03 -0400 Subject: [PATCH 15/15] revert change in gen-tpch.sh --- benchmarks/gen-tpch.sh | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/benchmarks/gen-tpch.sh b/benchmarks/gen-tpch.sh index 0e9de99..98ec9c8 100755 --- a/benchmarks/gen-tpch.sh +++ b/benchmarks/gen-tpch.sh @@ -2,14 +2,14 @@ set -e -SCALE_FACTOR=10 +SCALE_FACTOR=1 # https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" 
)" &> /dev/null && pwd ) DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} -if [ -z "$SCALE_FACTOR" ]; then +if [ -z "$SCALE_FACTOR" ] ; then echo "Internal error: Scale factor not specified" exit 1 fi @@ -36,7 +36,7 @@ if test -f "${FILE}"; then else echo " Copying answers to ${TPCH_DIR}/answers" mkdir -p "${TPCH_DIR}/answers" - docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" + docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" fi # Create 'parquet' files from tbl @@ -45,9 +45,9 @@ if test -d "${FILE}"; then echo " parquet files exist ($FILE exists)." else echo " creating parquet files using benchmark binary ..." - pushd "${SCRIPT_DIR}" >/dev/null + pushd "${SCRIPT_DIR}" > /dev/null $CARGO_COMMAND -- tpch-convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet - popd >/dev/null + popd > /dev/null fi # Create 'csv' files from tbl @@ -56,7 +56,8 @@ if test -d "${FILE}"; then echo " csv files exist ($FILE exists)." else echo " creating csv files using benchmark binary ..." - pushd "${SCRIPT_DIR}" >/dev/null + pushd "${SCRIPT_DIR}" > /dev/null $CARGO_COMMAND -- tpch-convert --input "${TPCH_DIR}" --output "${TPCH_DIR}/csv" --format csv - popd >/dev/null + popd > /dev/null fi +