From 3b45cc2c6c69c0b36beae2dbfe87a8e3be7db0ed Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Mon, 18 Aug 2025 16:03:54 +0200 Subject: [PATCH 1/9] Add `ConfigExtensionExt`, allowing the propagation of arbitrary [ConfigExtension]s across network boundaries --- src/config_extension_ext.rs | 456 +++++++++++++++++++++++++++++++ src/flight_service/do_get.rs | 19 +- src/lib.rs | 2 + src/plan/arrow_flight_read.rs | 26 +- tests/custom_config_extension.rs | 180 ++++++++++++ 5 files changed, 677 insertions(+), 6 deletions(-) create mode 100644 src/config_extension_ext.rs create mode 100644 tests/custom_config_extension.rs diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs new file mode 100644 index 0000000..3daf7f0 --- /dev/null +++ b/src/config_extension_ext.rs @@ -0,0 +1,456 @@ +use datafusion::common::{internal_datafusion_err, DataFusionError}; +use datafusion::config::ConfigExtension; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use delegate::delegate; +use http::{HeaderMap, HeaderName}; +use std::error::Error; +use std::str::FromStr; +use std::sync::Arc; + +const FLIGHT_METADATA_PREFIX: &str = "x-datafusion-distributed-"; + +/// Extension trait for `SessionConfig` to add support for propagating [ConfigExtension]s across +/// network calls. +pub trait ConfigExtensionExt { + /// Adds the provided [ConfigExtension] to the distributed context. The [ConfigExtension] will + /// be serialized using gRPC metadata and sent across tasks. Users are expected to call this + /// method with their own extensions to be able to access them in any place in the + /// plan. + /// + /// This method also adds the provided [ConfigExtension] to the current session option + /// extensions, the same as calling [SessionConfig::with_option_extension]. + /// + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::SessionState; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut config = SessionConfig::new(); + /// let mut opt = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// config.add_distributed_option_extension(opt).unwrap(); + /// + /// struct MyCustomSessionBuilder; + /// + /// #[async_trait] + /// impl SessionBuilder for MyCustomSessionBuilder { + /// async fn session_state(&self, mut state: SessionState) -> Result { + /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will + /// // know how to deserialize the CustomExtension from the gRPC metadata. + /// state.propagate_distributed_option_extension::()?; + /// Ok(state) + /// } + /// } + /// ``` + fn add_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError>; + + /// Gets the specified [ConfigExtension] from the distributed context and adds it to + /// the [SessionConfig::options] extensions. The function will build a new [ConfigExtension] + /// out of the Arrow Flight gRPC metadata present in the [SessionConfig] and will propagate it + /// to the extension options. + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::SessionState; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut config = SessionConfig::new(); + /// let mut opt = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// config.add_distributed_option_extension(opt).unwrap(); + /// + /// struct MyCustomSessionBuilder; + /// + /// #[async_trait] + /// impl SessionBuilder for MyCustomSessionBuilder { + /// async fn session_state(&self, mut state: SessionState) -> Result { + /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will + /// // know how to deserialize the CustomExtension from the gRPC metadata. + /// state.propagate_distributed_option_extension::()?; + /// Ok(state) + /// } + /// } + /// ``` + fn propagate_distributed_option_extension( + &mut self, + ) -> Result<(), DataFusionError>; +} + +impl ConfigExtensionExt for SessionConfig { + fn add_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError> { + fn parse_err(err: impl Error) -> DataFusionError { + DataFusionError::Internal(format!("Failed to add config extension: {err}")) + } + let mut meta = HeaderMap::new(); + + for entry in t.entries() { + if let Some(value) = entry.value { + meta.insert( + HeaderName::from_str(&format!( + "{}{}.{}", + FLIGHT_METADATA_PREFIX, + T::PREFIX, + entry.key + )) + .map_err(parse_err)?, + value.parse().map_err(parse_err)?, + ); + } + } + let flight_metadata = ContextGrpcMetadata(meta); + match self.get_extension::() { + None => self.set_extension(Arc::new(flight_metadata)), + Some(prev) => { + let prev = prev.as_ref().clone(); + self.set_extension(Arc::new(prev.merge(flight_metadata))) + } + } + self.options_mut().extensions.insert(t); + Ok(()) + } + + fn propagate_distributed_option_extension( + &mut self, + ) -> Result<(), DataFusionError> { + let Some(flight_metadata) = self.get_extension::() else { + return Ok(()); + }; + + let mut result = T::default(); + let mut found_some = false; + for (k, v) in flight_metadata.0.iter() { + let key = k.as_str().trim_start_matches(FLIGHT_METADATA_PREFIX); + if key.starts_with(T::PREFIX) { + found_some = true; + result.set( + &key.trim_start_matches(T::PREFIX).trim_start_matches("."), + v.to_str().map_err(|err| { + internal_datafusion_err!("Cannot parse header value: {err}") + })?, + )?; + } + } + if !found_some { + return Ok(()); + } + self.options_mut().extensions.insert(result); + Ok(()) + } +} + +impl ConfigExtensionExt for SessionStateBuilder { + delegate! { + to self.config().get_or_insert_default() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +impl ConfigExtensionExt for SessionState { + delegate! { + to self.config_mut() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +impl ConfigExtensionExt for SessionContext { + delegate! { + to self.state_ref().write().config_mut() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct ContextGrpcMetadata(pub HeaderMap); + +impl ContextGrpcMetadata { + pub(crate) fn from_headers(metadata: HeaderMap) -> Self { + let mut new = HeaderMap::new(); + for (k, v) in metadata.into_iter() { + let Some(k) = k else { continue }; + if k.as_str().starts_with(FLIGHT_METADATA_PREFIX) { + new.insert(k, v); + } + } + Self(new) + } + + fn merge(mut self, other: Self) -> Self { + for (k, v) in other.0.into_iter() { + let Some(k) = k else { continue }; + self.0.insert(k, v); + } + self + } +} + +#[cfg(test)] +mod tests { + use crate::config_extension_ext::ContextGrpcMetadata; + use crate::ConfigExtensionExt; + use datafusion::common::extensions_options; + use datafusion::config::ConfigExtension; + use datafusion::prelude::SessionConfig; + use http::{HeaderMap, HeaderName, HeaderValue}; + use std::str::FromStr; + + #[test] + fn test_propagation() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut opt = CustomExtension::default(); + opt.foo = "foo".to_string(); + opt.bar = 1; + opt.baz = true; + + config.add_distributed_option_extension(opt)?; + + let mut new_config = SessionConfig::new(); + new_config.set_extension(config.get_extension::().unwrap()); + new_config.propagate_distributed_option_extension::()?; + + let opt = get_ext::(&config); + let new_opt = get_ext::(&new_config); + + assert_eq!(new_opt.foo, opt.foo); + assert_eq!(new_opt.bar, opt.bar); + assert_eq!(new_opt.baz, opt.baz); + + Ok(()) + } + + #[test] + fn test_add_extension_with_empty_values() -> Result<(), Box> { + let mut config = SessionConfig::new(); + let opt = CustomExtension::default(); + + config.add_distributed_option_extension(opt)?; + + let flight_metadata = config.get_extension::(); + assert!(flight_metadata.is_some()); + + let metadata = &flight_metadata.unwrap().0; + assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-custom.baz")); + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-custom.bar"), "0"); + assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + + Ok(()) + } + + #[test] + fn test_new_extension_overwrites_previous() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut opt1 = CustomExtension::default(); + opt1.foo = "first".to_string(); + config.add_distributed_option_extension(opt1)?; + + let mut opt2 = CustomExtension::default(); + opt2.bar = 42; + config.add_distributed_option_extension(opt2)?; + + let flight_metadata = config.get_extension::().unwrap(); + let metadata = &flight_metadata.0; + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-custom.bar"), "42"); + assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + + Ok(()) + } + + #[test] + fn test_propagate_no_metadata() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + config.propagate_distributed_option_extension::()?; + + let extension = config.options().extensions.get::(); + assert!(extension.is_none()); + + Ok(()) + } + + #[test] + fn test_propagate_no_matching_prefix() -> Result<(), Box> { + let mut config = SessionConfig::new(); + let mut header_map = HeaderMap::new(); + header_map.insert( + HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(), + HeaderValue::from_str("value").unwrap(), + ); + + let flight_metadata = ContextGrpcMetadata::from_headers(header_map); + config.set_extension(std::sync::Arc::new(flight_metadata)); + config.propagate_distributed_option_extension::()?; + + let extension = config.options().extensions.get::(); + assert!(extension.is_none()); + + Ok(()) + } + + #[test] + fn test_multiple_extensions_different_prefixes() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut custom_opt = CustomExtension::default(); + custom_opt.foo = "custom_value".to_string(); + custom_opt.bar = 123; + + let mut another_opt = AnotherExtension::default(); + another_opt.setting1 = "other".to_string(); + another_opt.setting2 = 456; + + config.add_distributed_option_extension(custom_opt)?; + config.add_distributed_option_extension(another_opt)?; + + let flight_metadata = config.get_extension::().unwrap(); + let metadata = &flight_metadata.0; + + assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-another.setting1")); + assert!(metadata.contains_key("x-datafusion-distributed-another.setting2")); + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + + assert_eq!(get("x-datafusion-distributed-custom.foo"), "custom_value"); + assert_eq!(get("x-datafusion-distributed-custom.bar"), "123"); + assert_eq!(get("x-datafusion-distributed-another.setting1"), "other"); + assert_eq!(get("x-datafusion-distributed-another.setting2"), "456"); + + let mut new_config = SessionConfig::new(); + new_config.set_extension(flight_metadata); + new_config.propagate_distributed_option_extension::()?; + new_config.propagate_distributed_option_extension::()?; + + let propagated_custom = get_ext::(&new_config); + let propagated_another = get_ext::(&new_config); + + assert_eq!(propagated_custom.foo, "custom_value"); + assert_eq!(propagated_custom.bar, 123); + assert_eq!(propagated_another.setting1, "other"); + assert_eq!(propagated_another.setting2, 456); + + Ok(()) + } + + #[test] + fn test_invalid_header_name() { + let mut config = SessionConfig::new(); + let extension = InvalidExtension::default(); + + let result = config.add_distributed_option_extension(extension); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_header_value() { + let mut config = SessionConfig::new(); + let extension = InvalidValueExtension::default(); + + let result = config.add_distributed_option_extension(extension); + assert!(result.is_err()); + } + + extensions_options! { + pub struct CustomExtension { + pub foo: String, default = "".to_string() + pub bar: usize, default = 0 + pub baz: bool, default = false + } + } + + impl ConfigExtension for CustomExtension { + const PREFIX: &'static str = "custom"; + } + + extensions_options! { + pub struct AnotherExtension { + pub setting1: String, default = "default1".to_string() + pub setting2: usize, default = 42 + } + } + + impl ConfigExtension for AnotherExtension { + const PREFIX: &'static str = "another"; + } + + extensions_options! { + pub struct InvalidExtension { + pub key_with_spaces: String, default = "value".to_string() + } + } + + impl ConfigExtension for InvalidExtension { + const PREFIX: &'static str = "invalid key with spaces"; + } + + extensions_options! { + pub struct InvalidValueExtension { + pub key: String, default = "\u{0000}invalid\u{0001}".to_string() + } + } + + impl ConfigExtension for InvalidValueExtension { + const PREFIX: &'static str = "invalid_value"; + } + + fn get_ext(cfg: &SessionConfig) -> &T { + cfg.options().extensions.get::().unwrap() + } +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 38b7d25..2c14bba 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,4 +1,6 @@ +use super::service::StageKey; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; use crate::plan::{DistributedCodec, PartitionGroup}; @@ -10,13 +12,13 @@ use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::optimizer::OptimizerConfig; +use datafusion::prelude::SessionConfig; use futures::TryStreamExt; use prost::Message; use std::sync::Arc; +use tonic::metadata::MetadataMap; use tonic::{Request, Response, Status}; -use super::service::StageKey; - #[derive(Clone, PartialEq, ::prost::Message)] pub struct DoGet { /// The ExecutionStage that we are going to execute @@ -41,14 +43,15 @@ impl ArrowFlightEndpoint { &self, request: Request, ) -> Result::DoGetStream>, Status> { - let Ticket { ticket } = request.into_inner(); + let (metadata, _ext, ticket) = request.into_parts(); + let Ticket { ticket } = ticket; let doget = DoGet::decode(ticket).map_err(|err| { Status::invalid_argument(format!("Cannot decode DoGet message: {err}")) })?; let partition = doget.partition as usize; let task_number = doget.task_number as usize; - let (mut state, stage) = self.get_state_and_stage(doget).await?; + let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?; // find out which partition group we are executing let task = stage @@ -87,6 +90,7 @@ impl ArrowFlightEndpoint { async fn get_state_and_stage( &self, doget: DoGet, + metadata: MetadataMap, ) -> Result<(SessionState, Arc), Status> { let key = doget .stage_key @@ -102,9 +106,16 @@ impl ArrowFlightEndpoint { .stage_proto .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; + let mut config = SessionConfig::default(); + config.set_extension(Arc::new(ContextGrpcMetadata::from_headers( + metadata.into_headers(), + ))); + let state_builder = SessionStateBuilder::new() .with_runtime_env(Arc::clone(&self.runtime)) + .with_config(config) .with_default_features(); + let state_builder = self .session_builder .session_state_builder(state_builder) diff --git a/src/lib.rs b/src/lib.rs index 826bdf7..40515a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ mod channel_manager; mod common; mod composed_extension_codec; +mod config_extension_ext; mod errors; mod flight_service; mod physical_optimizer; @@ -15,6 +16,7 @@ mod user_provided_codec; pub mod test_utils; pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; +pub use config_extension_ext::ConfigExtensionExt; pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder}; pub use physical_optimizer::DistributedPhysicalOptimizerRule; pub use plan::ArrowFlightReadExec; diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index 62c5e76..bf3cb0b 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,6 +1,7 @@ use super::combined::CombinedRecordBatchStream; use crate::channel_manager::ChannelManager; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::tonic_status_to_datafusion_error; use crate::flight_service::{DoGet, StageKey}; use crate::plan::DistributedCodec; @@ -19,10 +20,13 @@ use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use futures::{future, TryFutureExt, TryStreamExt}; +use http::Extensions; use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; +use tonic::metadata::MetadataMap; +use tonic::Request; use url::Url; /// This node has two variants. @@ -173,6 +177,10 @@ impl ExecutionPlan for ArrowFlightReadExec { this.stage_num ))?; + let flight_metadata = context + .session_config() + .get_extension::(); + let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec {}); if let Some(ref user_codec) = get_user_codec(context.session_config()) { @@ -195,6 +203,7 @@ impl ExecutionPlan for ArrowFlightReadExec { let channel_manager_capture = channel_manager.clone(); let schema = schema.clone(); let query_id = query_id.clone(); + let flight_metadata = flight_metadata.clone().unwrap_or_default().as_ref().clone(); let key = StageKey { query_id, stage_id: child_stage_num, @@ -218,8 +227,14 @@ impl ExecutionPlan for ArrowFlightReadExec { ticket: ticket_bytes, }; - stream_from_stage_task(ticket, &url, schema.clone(), &channel_manager_capture) - .await + stream_from_stage_task( + ticket, + flight_metadata, + &url, + schema.clone(), + &channel_manager_capture, + ) + .await } }); @@ -240,12 +255,19 @@ impl ExecutionPlan for ArrowFlightReadExec { async fn stream_from_stage_task( ticket: Ticket, + metadata: ContextGrpcMetadata, url: &Url, schema: SchemaRef, channel_manager: &ChannelManager, ) -> Result { let channel = channel_manager.get_channel_for_url(url).await?; + let ticket = Request::from_parts( + MetadataMap::from_headers(metadata.0), + Extensions::default(), + ticket, + ); + let mut client = FlightServiceClient::new(channel); let stream = client .do_get(ticket) diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs new file mode 100644 index 0000000..33f3fb9 --- /dev/null +++ b/tests/custom_config_extension.rs @@ -0,0 +1,180 @@ +#[cfg(all(feature = "integration", test))] +mod tests { + use async_trait::async_trait; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::{extensions_options, internal_err}; + use datafusion::config::ConfigExtension; + use datafusion::error::DataFusionError; + use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionState, TaskContext, + }; + use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; + use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::physical_plan::{ + execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + }; + use datafusion_distributed::test_utils::localhost::start_localhost_context; + use datafusion_distributed::{add_user_codec, ConfigExtensionExt}; + use datafusion_distributed::{ + ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + }; + use datafusion_proto::physical_plan::PhysicalExtensionCodec; + use futures::TryStreamExt; + use prost::Message; + use std::any::Any; + use std::fmt::Formatter; + use std::sync::Arc; + + #[tokio::test] + async fn custom_config_extension() -> Result<(), Box> { + #[derive(Clone)] + struct CustomSessionBuilder; + + #[async_trait] + impl SessionBuilder for CustomSessionBuilder { + async fn session_state( + &self, + mut state: SessionState, + ) -> Result { + state.propagate_distributed_option_extension::()?; + add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); + Ok(state) + } + } + + let (mut ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + add_user_codec(&mut ctx, CustomConfigExtensionRequiredExecCodec); + ctx.add_distributed_option_extension(CustomExtension { + foo: "foo".to_string(), + bar: 1, + baz: true, + })?; + + let mut plan: Arc = Arc::new(CustomConfigExtensionRequiredExec::new()); + + for size in [1, 2, 3] { + plan = Arc::new(ArrowFlightReadExec::new_pending( + plan, + Partitioning::RoundRobinBatch(size), + )); + } + + let plan = DistributedPhysicalOptimizerRule::default().distribute_plan(plan)?; + let stream = execute_stream(Arc::new(plan), ctx.task_ctx())?; + // It should not fail. + stream.try_collect::>().await?; + + Ok(()) + } + + extensions_options! { + pub struct CustomExtension { + pub foo: String, default = "".to_string() + pub bar: usize, default = 0 + pub baz: bool, default = false + } + } + + impl ConfigExtension for CustomExtension { + const PREFIX: &'static str = "custom"; + } + + #[derive(Debug)] + pub struct CustomConfigExtensionRequiredExec { + plan_properties: PlanProperties, + } + + impl CustomConfigExtensionRequiredExec { + fn new() -> Self { + let schema = Schema::new(vec![Field::new("numbers", DataType::Int64, false)]); + Self { + plan_properties: PlanProperties::new( + EquivalenceProperties::new(Arc::new(schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ), + } + } + } + + impl DisplayAs for CustomConfigExtensionRequiredExec { + fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CustomConfigExtensionRequiredExec") + } + } + + impl ExecutionPlan for CustomConfigExtensionRequiredExec { + fn name(&self) -> &str { + "CustomConfigExtensionRequiredExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + _: usize, + ctx: Arc, + ) -> datafusion::common::Result { + if ctx + .session_config() + .options() + .extensions + .get::() + .is_none() + { + return internal_err!("CustomExtension not found in context"); + } + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::empty(), + ))) + } + } + + #[derive(Debug)] + struct CustomConfigExtensionRequiredExecCodec; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct CustomConfigExtensionRequiredExecProto {} + + impl PhysicalExtensionCodec for CustomConfigExtensionRequiredExecCodec { + fn try_decode( + &self, + _buf: &[u8], + _: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> datafusion::common::Result> { + Ok(Arc::new(CustomConfigExtensionRequiredExec::new())) + } + + fn try_encode( + &self, + _node: Arc, + buf: &mut Vec, + ) -> datafusion::common::Result<()> { + CustomConfigExtensionRequiredExecProto::default() + .encode(buf) + .unwrap(); + Ok(()) + } + } +} From ae8b6c97d9ef312804034545014356628b006cce Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 16:53:05 +0200 Subject: [PATCH 2/9] Rename propagate_distributed_option_extension to retrieve_distributed_option_extension --- src/config_extension_ext.rs | 24 ++++++++++++------------ tests/custom_config_extension.rs | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 3daf7f0..449a28a 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -56,7 +56,7 @@ pub trait ConfigExtensionExt { /// async fn session_state(&self, mut state: SessionState) -> Result { /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.propagate_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::()?; /// Ok(state) /// } /// } @@ -105,12 +105,12 @@ pub trait ConfigExtensionExt { /// async fn session_state(&self, mut state: SessionState) -> Result { /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.propagate_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::()?; /// Ok(state) /// } /// } /// ``` - fn propagate_distributed_option_extension( + fn retrieve_distributed_option_extension( &mut self, ) -> Result<(), DataFusionError>; } @@ -151,7 +151,7 @@ impl ConfigExtensionExt for SessionConfig { Ok(()) } - fn propagate_distributed_option_extension( + fn retrieve_distributed_option_extension( &mut self, ) -> Result<(), DataFusionError> { let Some(flight_metadata) = self.get_extension::() else { @@ -184,7 +184,7 @@ impl ConfigExtensionExt for SessionStateBuilder { delegate! { to self.config().get_or_insert_default() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; } } } @@ -193,7 +193,7 @@ impl ConfigExtensionExt for SessionState { delegate! { to self.config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; } } } @@ -202,7 +202,7 @@ impl ConfigExtensionExt for SessionContext { delegate! { to self.state_ref().write().config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn propagate_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; } } } @@ -254,7 +254,7 @@ mod tests { let mut new_config = SessionConfig::new(); new_config.set_extension(config.get_extension::().unwrap()); - new_config.propagate_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::()?; let opt = get_ext::(&config); let new_opt = get_ext::(&new_config); @@ -316,7 +316,7 @@ mod tests { fn test_propagate_no_metadata() -> Result<(), Box> { let mut config = SessionConfig::new(); - config.propagate_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::()?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -335,7 +335,7 @@ mod tests { let flight_metadata = ContextGrpcMetadata::from_headers(header_map); config.set_extension(std::sync::Arc::new(flight_metadata)); - config.propagate_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::()?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -375,8 +375,8 @@ mod tests { let mut new_config = SessionConfig::new(); new_config.set_extension(flight_metadata); - new_config.propagate_distributed_option_extension::()?; - new_config.propagate_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::()?; let propagated_custom = get_ext::(&new_config); let propagated_another = get_ext::(&new_config); diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index 33f3fb9..d671d90 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -37,7 +37,7 @@ mod tests { &self, mut state: SessionState, ) -> Result { - state.propagate_distributed_option_extension::()?; + state.retrieve_distributed_option_extension::()?; add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); Ok(state) } From 9ac83b0b5b4d45464fcb6cfe8fe0e66650a5d6ef Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 16:55:52 +0200 Subject: [PATCH 3/9] Change x-datafusion-distributed- to x-datafusion-distributed-config- --- src/config_extension_ext.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 449a28a..6476ce8 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -8,7 +8,7 @@ use std::error::Error; use std::str::FromStr; use std::sync::Arc; -const FLIGHT_METADATA_PREFIX: &str = "x-datafusion-distributed-"; +const FLIGHT_METADATA_CONFIG_PREFIX: &str = "x-datafusion-distributed-config-"; /// Extension trait for `SessionConfig` to add support for propagating [ConfigExtension]s across /// network calls. @@ -130,7 +130,7 @@ impl ConfigExtensionExt for SessionConfig { meta.insert( HeaderName::from_str(&format!( "{}{}.{}", - FLIGHT_METADATA_PREFIX, + FLIGHT_METADATA_CONFIG_PREFIX, T::PREFIX, entry.key )) @@ -161,7 +161,7 @@ impl ConfigExtensionExt for SessionConfig { let mut result = T::default(); let mut found_some = false; for (k, v) in flight_metadata.0.iter() { - let key = k.as_str().trim_start_matches(FLIGHT_METADATA_PREFIX); + let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); if key.starts_with(T::PREFIX) { found_some = true; result.set( @@ -215,7 +215,7 @@ impl ContextGrpcMetadata { let mut new = HeaderMap::new(); for (k, v) in metadata.into_iter() { let Some(k) = k else { continue }; - if k.as_str().starts_with(FLIGHT_METADATA_PREFIX) { + if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) { new.insert(k, v); } } From e15796e3e97dadec31eaf148f07fd7f965ffb26b Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 16:59:35 +0200 Subject: [PATCH 4/9] Check for full format!("{}.", T::PREFIX) in gRPC keys --- src/config_extension_ext.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 6476ce8..5a26e15 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -162,10 +162,11 @@ impl ConfigExtensionExt for SessionConfig { let mut found_some = false; for (k, v) in flight_metadata.0.iter() { let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); - if key.starts_with(T::PREFIX) { + let prefix = format!("{}.", T::PREFIX); + if key.starts_with(&prefix) { found_some = true; result.set( - &key.trim_start_matches(T::PREFIX).trim_start_matches("."), + &key.trim_start_matches(&prefix), v.to_str().map_err(|err| { internal_datafusion_err!("Cannot parse header value: {err}") })?, From 104e3fa46da5b1af78396d9fd14a7ef49fcf7377 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 17:04:24 +0200 Subject: [PATCH 5/9] Remove double clone --- src/plan/arrow_flight_read.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index bf3cb0b..36d7786 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -203,7 +203,10 @@ impl ExecutionPlan for ArrowFlightReadExec { let channel_manager_capture = channel_manager.clone(); let schema = schema.clone(); let query_id = query_id.clone(); - let flight_metadata = flight_metadata.clone().unwrap_or_default().as_ref().clone(); + let flight_metadata = flight_metadata + .as_ref() + .map(|v| v.as_ref().clone()) + .unwrap_or_default(); let key = StageKey { query_id, stage_id: child_stage_num, From 8576c2c64fc10f881d75a0c2e9d9972b1352aae8 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 17:23:36 +0200 Subject: [PATCH 6/9] Fix tests --- src/config_extension_ext.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 5a26e15..22beb83 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -283,9 +283,9 @@ mod tests { assert!(metadata.contains_key("x-datafusion-distributed-custom.baz")); let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); - assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); - assert_eq!(get("x-datafusion-distributed-custom.bar"), "0"); - assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "0"); + assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false"); Ok(()) } @@ -306,9 +306,9 @@ mod tests { let metadata = &flight_metadata.0; let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); - assert_eq!(get("x-datafusion-distributed-custom.foo"), ""); - assert_eq!(get("x-datafusion-distributed-custom.bar"), "42"); - assert_eq!(get("x-datafusion-distributed-custom.baz"), "false"); + assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "42"); + assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false"); Ok(()) } From d35ddff726f5f2f27cee451334ba93d9f1fbe761 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 19 Aug 2025 17:24:28 +0200 Subject: [PATCH 7/9] Fix tests --- src/config_extension_ext.rs | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 22beb83..41f5141 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -278,9 +278,9 @@ mod tests { assert!(flight_metadata.is_some()); let metadata = &flight_metadata.unwrap().0; - assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); - assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); - assert!(metadata.contains_key("x-datafusion-distributed-custom.baz")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.baz")); let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); @@ -362,17 +362,26 @@ mod tests { let flight_metadata = config.get_extension::().unwrap(); let metadata = &flight_metadata.0; - assert!(metadata.contains_key("x-datafusion-distributed-custom.foo")); - assert!(metadata.contains_key("x-datafusion-distributed-custom.bar")); - assert!(metadata.contains_key("x-datafusion-distributed-another.setting1")); - assert!(metadata.contains_key("x-datafusion-distributed-another.setting2")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-config-another.setting1")); + assert!(metadata.contains_key("x-datafusion-distributed-config-another.setting2")); let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); - assert_eq!(get("x-datafusion-distributed-custom.foo"), "custom_value"); - assert_eq!(get("x-datafusion-distributed-custom.bar"), "123"); - assert_eq!(get("x-datafusion-distributed-another.setting1"), "other"); - assert_eq!(get("x-datafusion-distributed-another.setting2"), "456"); + assert_eq!( + get("x-datafusion-distributed-config-custom.foo"), + "custom_value" + ); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "123"); + assert_eq!( + get("x-datafusion-distributed-config-another.setting1"), + "other" + ); + assert_eq!( + get("x-datafusion-distributed-config-another.setting2"), + "456" + ); let mut new_config = SessionConfig::new(); new_config.set_extension(flight_metadata); From 25d24d2a01e7cb17ef3202d957590f34b0df30d0 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 21 Aug 2025 14:28:18 +0200 Subject: [PATCH 8/9] Rework SessionBuilder --- benchmarks/src/tpch/run.rs | 29 ++--- src/common/mod.rs | 1 + src/config_extension_ext.rs | 66 +++++------ src/flight_service/do_get.rs | 51 +++------ src/flight_service/mod.rs | 4 +- src/flight_service/service.rs | 10 +- src/flight_service/session_builder.rs | 130 ++++++++++------------ src/lib.rs | 5 +- src/test_utils/localhost.rs | 37 +++--- tests/custom_config_extension.rs | 33 +++--- tests/custom_extension_codec.rs | 26 ++--- tests/distributed_aggregation.rs | 4 +- tests/error_propagation.rs | 26 +++-- tests/highly_distributed_query.rs | 4 +- tests/non_distributed_consistency_test.rs | 74 ++++++------ 15 files changed, 218 insertions(+), 282 deletions(-) diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 507e408..f3cd2bb 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -36,14 +36,16 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::SessionStateBuilder; +use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; use datafusion_distributed::test_utils::localhost::start_localhost_context; -use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; +use datafusion_distributed::{ + DistributedPhysicalOptimizerRule, DistributedSessionBuilder, DistributedSessionBuilderContext, +}; use log::info; use structopt::StructOpt; @@ -110,11 +112,13 @@ pub struct RunOpt { } #[async_trait] -impl SessionBuilder for RunOpt { - fn session_state_builder( +impl DistributedSessionBuilder for RunOpt { + async fn build_session_state( &self, - mut builder: SessionStateBuilder, - ) -> Result { + _ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut builder = SessionStateBuilder::new().with_default_features(); + let mut config = self .common .config()? @@ -145,17 +149,14 @@ impl SessionBuilder for RunOpt { builder = builder.with_physical_optimizer_rule(Arc::new(rule)); } - Ok(builder + let state = builder .with_config(config) - .with_runtime_env(rt_builder.build_arc()?)) - } + .with_runtime_env(rt_builder.build_arc()?) + .build(); - async fn session_context( - &self, - ctx: SessionContext, - ) -> std::result::Result { + let ctx = SessionContext::from(state); self.register_tables(&ctx).await?; - Ok(ctx) + Ok(ctx.state()) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 572b996..6e77e1a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,2 +1,3 @@ +#[allow(unused)] pub mod ttl_map; pub mod util; diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 41f5141..eb927e0 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -27,9 +27,9 @@ pub trait ConfigExtensionExt { /// # use async_trait::async_trait; /// # use datafusion::common::{extensions_options, DataFusionError}; /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::SessionState; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// extensions_options! { /// pub struct CustomExtension { @@ -52,11 +52,13 @@ pub trait ConfigExtensionExt { /// struct MyCustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for MyCustomSessionBuilder { - /// async fn session_state(&self, mut state: SessionState) -> Result { + /// impl DistributedSessionBuilder for MyCustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { + /// let mut state = SessionStateBuilder::new().build(); + /// /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.retrieve_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::(&ctx.headers)?; /// Ok(state) /// } /// } @@ -76,9 +78,9 @@ pub trait ConfigExtensionExt { /// # use async_trait::async_trait; /// # use datafusion::common::{extensions_options, DataFusionError}; /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::SessionState; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// extensions_options! { /// pub struct CustomExtension { @@ -101,17 +103,19 @@ pub trait ConfigExtensionExt { /// struct MyCustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for MyCustomSessionBuilder { - /// async fn session_state(&self, mut state: SessionState) -> Result { + /// impl DistributedSessionBuilder for MyCustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { + /// let mut state = SessionStateBuilder::new().build(); /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.retrieve_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::(&ctx.headers)?; /// Ok(state) /// } /// } /// ``` fn retrieve_distributed_option_extension( &mut self, + headers: &HeaderMap, ) -> Result<(), DataFusionError>; } @@ -153,14 +157,11 @@ impl ConfigExtensionExt for SessionConfig { fn retrieve_distributed_option_extension( &mut self, + headers: &HeaderMap, ) -> Result<(), DataFusionError> { - let Some(flight_metadata) = self.get_extension::() else { - return Ok(()); - }; - let mut result = T::default(); let mut found_some = false; - for (k, v) in flight_metadata.0.iter() { + for (k, v) in headers.iter() { let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); let prefix = format!("{}.", T::PREFIX); if key.starts_with(&prefix) { @@ -185,7 +186,7 @@ impl ConfigExtensionExt for SessionStateBuilder { delegate! { to self.config().get_or_insert_default() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -194,7 +195,7 @@ impl ConfigExtensionExt for SessionState { delegate! { to self.config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -203,7 +204,7 @@ impl ConfigExtensionExt for SessionContext { delegate! { to self.state_ref().write().config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -212,17 +213,6 @@ impl ConfigExtensionExt for SessionContext { pub(crate) struct ContextGrpcMetadata(pub HeaderMap); impl ContextGrpcMetadata { - pub(crate) fn from_headers(metadata: HeaderMap) -> Self { - let mut new = HeaderMap::new(); - for (k, v) in metadata.into_iter() { - let Some(k) = k else { continue }; - if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) { - new.insert(k, v); - } - } - Self(new) - } - fn merge(mut self, other: Self) -> Self { for (k, v) in other.0.into_iter() { let Some(k) = k else { continue }; @@ -252,10 +242,9 @@ mod tests { opt.baz = true; config.add_distributed_option_extension(opt)?; - + let metadata = config.get_extension::().unwrap(); let mut new_config = SessionConfig::new(); - new_config.set_extension(config.get_extension::().unwrap()); - new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::(&metadata.0)?; let opt = get_ext::(&config); let new_opt = get_ext::(&new_config); @@ -317,7 +306,7 @@ mod tests { fn test_propagate_no_metadata() -> Result<(), Box> { let mut config = SessionConfig::new(); - config.retrieve_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::(&Default::default())?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -330,13 +319,11 @@ mod tests { let mut config = SessionConfig::new(); let mut header_map = HeaderMap::new(); header_map.insert( - HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(), + HeaderName::from_str("x-datafusion-distributed-config-other.setting").unwrap(), HeaderValue::from_str("value").unwrap(), ); - let flight_metadata = ContextGrpcMetadata::from_headers(header_map); - config.set_extension(std::sync::Arc::new(flight_metadata)); - config.retrieve_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::(&header_map)?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -384,9 +371,8 @@ mod tests { ); let mut new_config = SessionConfig::new(); - new_config.set_extension(flight_metadata); - new_config.retrieve_distributed_option_extension::()?; - new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::(&metadata)?; + new_config.retrieve_distributed_option_extension::(&metadata)?; let propagated_custom = get_ext::(&new_config); let propagated_another = get_ext::(&new_config); diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 2c14bba..0be1c5b 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -3,6 +3,7 @@ use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; +use crate::flight_service::session_builder::DistributedSessionBuilderContext; use crate::plan::{DistributedCodec, PartitionGroup}; use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto}; use crate::user_provided_codec::get_user_codec; @@ -10,9 +11,7 @@ use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; -use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::optimizer::OptimizerConfig; -use datafusion::prelude::SessionConfig; +use datafusion::execution::SessionState; use futures::TryStreamExt; use prost::Message; use std::sync::Arc; @@ -90,7 +89,7 @@ impl ArrowFlightEndpoint { async fn get_state_and_stage( &self, doget: DoGet, - metadata: MetadataMap, + metadata_map: MetadataMap, ) -> Result<(SessionState, Arc), Status> { let key = doget .stage_key @@ -106,54 +105,34 @@ impl ArrowFlightEndpoint { .stage_proto .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; - let mut config = SessionConfig::default(); - config.set_extension(Arc::new(ContextGrpcMetadata::from_headers( - metadata.into_headers(), - ))); - - let state_builder = SessionStateBuilder::new() - .with_runtime_env(Arc::clone(&self.runtime)) - .with_config(config) - .with_default_features(); - - let state_builder = self - .session_builder - .session_state_builder(state_builder) - .map_err(|err| datafusion_error_to_tonic_status(&err))?; - - let state = state_builder.build(); + let headers = metadata_map.into_headers(); let mut state = self .session_builder - .session_state(state) + .build_session_state(DistributedSessionBuilderContext { + runtime_env: Arc::clone(&self.runtime), + headers: headers.clone(), + }) .await .map_err(|err| datafusion_error_to_tonic_status(&err))?; - let function_registry = - state.function_registry().ok_or(Status::invalid_argument( - "FunctionRegistry not present in newly built SessionState", - ))?; - let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec); if let Some(ref user_codec) = get_user_codec(state.config()) { combined_codec.push_arc(Arc::clone(user_codec)); } - let stage = stage_from_proto( - stage_proto, - function_registry, - self.runtime.as_ref(), - &combined_codec, - ) - .map(Arc::new) - .map_err(|err| { - Status::invalid_argument(format!("Cannot decode stage proto: {err}")) - })?; + let stage = + stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &combined_codec) + .map(Arc::new) + .map_err(|err| { + Status::invalid_argument(format!("Cannot decode stage proto: {err}")) + })?; // Add the extensions that might be required for ExecutionPlan nodes in the plan let config = state.config_mut(); config.set_extension(Arc::clone(&self.channel_manager)); config.set_extension(stage.clone()); + config.set_extension(Arc::new(ContextGrpcMetadata(headers))); Ok::<_, Status>((state, stage)) }) diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index e3b53c8..b389074 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -5,4 +5,6 @@ mod session_builder; pub(crate) use do_get::DoGet; pub use service::{ArrowFlightEndpoint, StageKey}; -pub use session_builder::{NoopSessionBuilder, SessionBuilder}; +pub use session_builder::{ + DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext, +}; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index fc19269..78e6bdb 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -1,6 +1,6 @@ use crate::channel_manager::ChannelManager; -use crate::flight_service::session_builder::NoopSessionBuilder; -use crate::flight_service::SessionBuilder; +use crate::flight_service::session_builder::DefaultSessionBuilder; +use crate::flight_service::DistributedSessionBuilder; use crate::stage::ExecutionStage; use crate::ChannelResolver; use arrow_flight::flight_service_server::FlightService; @@ -36,7 +36,7 @@ pub struct ArrowFlightEndpoint { pub(super) runtime: Arc, #[allow(clippy::type_complexity)] pub(super) stages: DashMap)>>>, - pub(super) session_builder: Arc, + pub(super) session_builder: Arc, } impl ArrowFlightEndpoint { @@ -45,13 +45,13 @@ impl ArrowFlightEndpoint { channel_manager: Arc::new(ChannelManager::new(channel_resolver)), runtime: Arc::new(RuntimeEnv::default()), stages: DashMap::new(), - session_builder: Arc::new(NoopSessionBuilder), + session_builder: Arc::new(DefaultSessionBuilder), } } pub fn with_session_builder( &mut self, - session_builder: impl SessionBuilder + Send + Sync + 'static, + session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, ) { self.session_builder = Arc::new(session_builder); } diff --git a/src/flight_service/session_builder.rs b/src/flight_service/session_builder.rs index 64be9dd..3decda0 100644 --- a/src/flight_service/session_builder.rs +++ b/src/flight_service/session_builder.rs @@ -1,27 +1,34 @@ use async_trait::async_trait; use datafusion::error::DataFusionError; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::prelude::SessionContext; +use http::HeaderMap; +use std::sync::Arc; + +#[derive(Debug, Clone, Default)] +pub struct DistributedSessionBuilderContext { + pub runtime_env: Arc, + pub headers: HeaderMap, +} /// Trait called by the Arrow Flight endpoint that handles distributed parts of a DataFusion -/// plan for building a DataFusion's [datafusion::prelude::SessionContext]. +/// plan for building a DataFusion's [SessionState]. #[async_trait] -pub trait SessionBuilder { - /// Takes a [SessionStateBuilder] and adds whatever is necessary for it to work, like - /// custom extension codecs, custom physical optimization rules, UDFs, UDAFs, config - /// extensions, etc... +pub trait DistributedSessionBuilder { + /// Builds a custom [SessionState] scoped to a single ArrowFlight gRPC call, allowing the + /// users to provide a customized DataFusion session with things like custom extension codecs, + /// custom physical optimization rules, UDFs, UDAFs, config extensions, etc... /// - /// Example: adding some custom extension plan codecs + /// Example: /// /// ```rust /// # use std::sync::Arc; /// # use async_trait::async_trait; /// # use datafusion::error::DataFusionError; - /// # use datafusion::execution::runtime_env::RuntimeEnv; - /// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; + /// # use datafusion::execution::{FunctionRegistry, SessionState, SessionStateBuilder}; /// # use datafusion::physical_plan::ExecutionPlan; /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec; - /// # use datafusion_distributed::{with_user_codec, SessionBuilder}; + /// # use datafusion_distributed::{with_user_codec, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// #[derive(Debug)] /// struct CustomExecCodec; @@ -40,79 +47,54 @@ pub trait SessionBuilder { /// struct CustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// fn session_state_builder(&self, mut builder: SessionStateBuilder) -> Result { - /// // Add your UDFs, optimization rules, etc... - /// Ok(with_user_codec(builder, CustomExecCodec)) - /// } - /// } - /// ``` - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(builder) - } - - /// Modifies the [SessionState] and returns it. Same as [SessionBuilder::session_state_builder] - /// but operating on an already built [SessionState]. - /// - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::DataFusionError; - /// # use datafusion::execution::SessionState; - /// # use datafusion_distributed::SessionBuilder; + /// impl DistributedSessionBuilder for CustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { + /// let builder = SessionStateBuilder::new() + /// .with_runtime_env(ctx.runtime_env.clone()) + /// .with_default_features(); /// - /// #[derive(Clone)] - /// struct CustomSessionBuilder; + /// let builder = with_user_codec(builder, CustomExecCodec); /// - /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// async fn session_state(&self, state: SessionState) -> Result { - /// // mutate the state adding any custom logic - /// Ok(state) - /// } - /// } - /// ``` - async fn session_state(&self, state: SessionState) -> Result { - Ok(state) - } - - /// Modifies the [SessionContext] and returns it. Same as [SessionBuilder::session_state_builder] - /// or [SessionBuilder::session_state] but operation on an already built [SessionContext]. - /// - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::DataFusionError; - /// # use datafusion::prelude::SessionContext; - /// # use datafusion_distributed::SessionBuilder; - /// - /// #[derive(Clone)] - /// struct CustomSessionBuilder; + /// // Add your UDFs, optimization rules, etc... /// - /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// async fn session_context(&self, ctx: SessionContext) -> Result { - /// // mutate the context adding any custom logic - /// Ok(ctx) + /// Ok(builder.build()) /// } /// } /// ``` - async fn session_context( + async fn build_session_state( &self, - ctx: SessionContext, - ) -> Result { - Ok(ctx) - } + ctx: DistributedSessionBuilderContext, + ) -> Result; } -/// Noop implementation of the [SessionBuilder]. Used by default if no [SessionBuilder] is provided +/// Noop implementation of the [DistributedSessionBuilder]. Used by default if no [DistributedSessionBuilder] is provided /// while building the Arrow Flight endpoint. #[derive(Debug, Clone)] -pub struct NoopSessionBuilder; +pub struct DefaultSessionBuilder; + +#[async_trait] +impl DistributedSessionBuilder for DefaultSessionBuilder { + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result { + Ok(SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env.clone()) + .with_default_features() + .build()) + } +} -impl SessionBuilder for NoopSessionBuilder {} +#[async_trait] +impl DistributedSessionBuilder for F +where + F: Fn(DistributedSessionBuilderContext) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, +{ + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result { + self(ctx).await + } +} diff --git a/src/lib.rs b/src/lib.rs index 40515a9..dbf8a5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,10 @@ pub mod test_utils; pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; pub use config_extension_ext::ConfigExtensionExt; -pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder}; +pub use flight_service::{ + ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder, + DistributedSessionBuilderContext, +}; pub use physical_optimizer::DistributedPhysicalOptimizerRule; pub use plan::ArrowFlightReadExec; pub use stage::{display_stage_graphviz, ExecutionStage}; diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index 83a328e..a782807 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -1,12 +1,13 @@ use crate::{ - ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, SessionBuilder, + ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, + DistributedSessionBuilder, DistributedSessionBuilderContext, }; use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; +use datafusion::common::runtime::JoinSet; use datafusion::common::DataFusionError; -use datafusion::execution::SessionStateBuilder; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::prelude::SessionContext; -use datafusion::{common::runtime::JoinSet, prelude::SessionConfig}; use std::error::Error; use std::sync::Arc; use std::time::Duration; @@ -19,7 +20,7 @@ pub async fn start_localhost_context( session_builder: B, ) -> (SessionContext, JoinSet<()>) where - B: SessionBuilder + Send + Sync + 'static, + B: DistributedSessionBuilder + Send + Sync + 'static, B: Clone, { let listeners = futures::future::try_join_all( @@ -53,25 +54,19 @@ where } tokio::time::sleep(Duration::from_millis(100)).await; - let config = SessionConfig::new().with_target_partitions(3); - - let builder = SessionStateBuilder::new() - .with_default_features() - .with_config(config); - let builder = session_builder.session_state_builder(builder).unwrap(); - - let state = builder.build(); - let state = session_builder.session_state(state).await.unwrap(); - - let ctx = SessionContext::new_with_state(state); - let ctx = session_builder.session_context(ctx).await.unwrap(); - - ctx.state_ref() - .write() + let mut state = session_builder + .build_session_state(DistributedSessionBuilderContext { + runtime_env: Arc::new(RuntimeEnv::default()), + headers: Default::default(), + }) + .await + .unwrap(); + state .config_mut() .set_extension(Arc::new(ChannelManager::new(channel_resolver))); + state.config_mut().options_mut().execution.target_partitions = 3; - (ctx, join_set) + (SessionContext::from(state), join_set) } #[derive(Clone)] @@ -108,7 +103,7 @@ impl ChannelResolver for LocalHostChannelResolver { pub async fn spawn_flight_service( channel_resolver: impl ChannelResolver + Send + Sync + 'static, - session_builder: impl SessionBuilder + Send + Sync + 'static, + session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, incoming: TcpListener, ) -> Result<(), Box> { let mut endpoint = ArrowFlightEndpoint::new(channel_resolver); diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index d671d90..af84a92 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -1,12 +1,11 @@ #[cfg(all(feature = "integration", test))] mod tests { - use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::{extensions_options, internal_err}; use datafusion::config::ConfigExtension; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionState, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -15,10 +14,10 @@ mod tests { execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{add_user_codec, ConfigExtensionExt}; use datafusion_distributed::{ - ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, ConfigExtensionExt, DistributedSessionBuilderContext, }; + use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::TryStreamExt; use prost::Message; @@ -28,23 +27,19 @@ mod tests { #[tokio::test] async fn custom_config_extension() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - - #[async_trait] - impl SessionBuilder for CustomSessionBuilder { - async fn session_state( - &self, - mut state: SessionState, - ) -> Result { - state.retrieve_distributed_option_extension::()?; - add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); - Ok(state) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + state.retrieve_distributed_option_extension::(&ctx.headers)?; + add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); + Ok(state) } - let (mut ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; - add_user_codec(&mut ctx, CustomConfigExtensionRequiredExecCodec); + let (mut ctx, _guard) = start_localhost_context(3, build_state).await; ctx.add_distributed_option_extension(CustomExtension { foo: "foo".to_string(), bar: 1, diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index 3d2870d..fa86082 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -7,7 +7,7 @@ mod tests { use datafusion::arrow::util::pretty::pretty_format_batches; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::logical_expr::Operator; use datafusion::physical_expr::expressions::{col, lit, BinaryExpr}; @@ -22,11 +22,11 @@ mod tests { use datafusion::physical_plan::{ displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; - use datafusion_distributed::assert_snapshot; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, assert_snapshot, DistributedSessionBuilderContext, }; + use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::{stream, TryStreamExt}; @@ -38,18 +38,18 @@ mod tests { #[tokio::test] #[ignore] async fn custom_extension_codec() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - impl SessionBuilder for CustomSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(with_user_codec(builder, Int64ListExecCodec)) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + add_user_codec(state.config_mut(), Int64ListExecCodec); + Ok(state) } - let (ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, build_state).await; let single_node_plan = build_plan(false)?; assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r" diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs index 1a572f2..273b332 100644 --- a/tests/distributed_aggregation.rs +++ b/tests/distributed_aggregation.rs @@ -5,13 +5,13 @@ mod tests { use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; use datafusion_distributed::test_utils::plan::distribute_aggregate; - use datafusion_distributed::{assert_snapshot, NoopSessionBuilder}; + use datafusion_distributed::{assert_snapshot, DefaultSessionBuilder}; use futures::TryStreamExt; use std::error::Error; #[tokio::test] async fn distributed_aggregation() -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(3, NoopSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index e8d2ca5..b030d09 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -3,7 +3,7 @@ mod tests { use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -13,7 +13,8 @@ mod tests { }; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, + DistributedSessionBuilderContext, }; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; @@ -26,17 +27,18 @@ mod tests { #[tokio::test] async fn test_error_propagation() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - impl SessionBuilder for CustomSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(with_user_codec(builder, ErrorExecCodec)) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + add_user_codec(state.config_mut(), ErrorExecCodec); + Ok(state) } - let (ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + + let (ctx, _guard) = start_localhost_context(3, build_state).await; let mut plan: Arc = Arc::new(ErrorExec::new("something failed")); diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index d6c2cfa..c2b8160 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -4,7 +4,7 @@ mod tests { use datafusion::physical_plan::{displayable, execute_stream}; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; - use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, NoopSessionBuilder}; + use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, DefaultSessionBuilder}; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -12,7 +12,7 @@ mod tests { #[tokio::test] #[ignore] async fn highly_distributed_query() -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(9, NoopSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(9, DefaultSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx.sql(r#"SELECT * FROM flights_1m"#).await?; diff --git a/tests/non_distributed_consistency_test.rs b/tests/non_distributed_consistency_test.rs index ea2f395..819289b 100644 --- a/tests/non_distributed_consistency_test.rs +++ b/tests/non_distributed_consistency_test.rs @@ -3,12 +3,13 @@ mod common; #[cfg(all(feature = "integration", test))] mod tests { use crate::common::{ensure_tpch_data, get_test_data_dir, get_test_tpch_query}; - use async_trait::async_trait; use datafusion::error::DataFusionError; - use datafusion::execution::SessionStateBuilder; + use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; + use datafusion_distributed::{ + DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, + }; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -126,48 +127,37 @@ mod tests { } async fn test_tpch_query(query_id: u8) -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(2, TestSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(2, build_state).await; run_tpch_query(ctx, query_id).await } - #[derive(Clone)] - struct TestSessionBuilder; - - #[async_trait] - impl SessionBuilder for TestSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - let mut config = SessionConfig::new().with_target_partitions(3); - - // FIXME: these three options are critical for the correct function of the library - // but we are not enforcing that the user sets them. They are here at the moment - // but we should figure out a way to do this better. - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold = 0; - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold_rows = 0; - - config.options_mut().optimizer.prefer_hash_join = true; - // end critical options section - - let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); - Ok(builder - .with_config(config) - .with_physical_optimizer_rule(Arc::new(rule))) - } - - async fn session_context( - &self, - ctx: SessionContext, - ) -> std::result::Result { - Ok(ctx) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut config = SessionConfig::new().with_target_partitions(3); + + // FIXME: these three options are critical for the correct function of the library + // but we are not enforcing that the user sets them. They are here at the moment + // but we should figure out a way to do this better. + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold = 0; + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold_rows = 0; + + config.options_mut().optimizer.prefer_hash_join = true; + // end critical options section + + let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); + Ok(SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .with_physical_optimizer_rule(Arc::new(rule)) + .build()) } // test_non_distributed_consistency runs each TPC-H query twice - once in a distributed manner From c692a6f0fe18d1b1bc2afc825a93f0a15c51a8ce Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 21 Aug 2025 14:36:21 +0200 Subject: [PATCH 9/9] Fix clippy errors --- src/common/ttl_map.rs | 6 +++--- src/config_extension_ext.rs | 43 ++++++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index abfd377..064ea98 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -94,7 +94,7 @@ where shard.insert(key); } BucketOp::Clear => { - let keys_to_delete = mem::replace(&mut shard, HashSet::new()); + let keys_to_delete = mem::take(&mut shard); for key in keys_to_delete { data.remove(&key); } @@ -253,14 +253,14 @@ where /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The /// function terminates if `shutdown` is signalled. - async fn run_gc_loop(time: Arc, period: Duration, buckets: &Vec>) { + async fn run_gc_loop(time: Arc, period: Duration, buckets: &[Bucket]) { loop { tokio::time::sleep(period).await; Self::gc(time.clone(), buckets); } } - fn gc(time: Arc, buckets: &Vec>) { + fn gc(time: Arc, buckets: &[Bucket]) { let index = time.load(std::sync::atomic::Ordering::SeqCst) % buckets.len() as u64; buckets[index as usize].clear(); time.fetch_add(1, std::sync::atomic::Ordering::SeqCst); diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index eb927e0..bab103a 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -167,7 +167,7 @@ impl ConfigExtensionExt for SessionConfig { if key.starts_with(&prefix) { found_some = true; result.set( - &key.trim_start_matches(&prefix), + key.trim_start_matches(&prefix), v.to_str().map_err(|err| { internal_datafusion_err!("Cannot parse header value: {err}") })?, @@ -236,10 +236,11 @@ mod tests { fn test_propagation() -> Result<(), Box> { let mut config = SessionConfig::new(); - let mut opt = CustomExtension::default(); - opt.foo = "foo".to_string(); - opt.bar = 1; - opt.baz = true; + let opt = CustomExtension { + foo: "".to_string(), + bar: 0, + baz: false, + }; config.add_distributed_option_extension(opt)?; let metadata = config.get_extension::().unwrap(); @@ -283,12 +284,16 @@ mod tests { fn test_new_extension_overwrites_previous() -> Result<(), Box> { let mut config = SessionConfig::new(); - let mut opt1 = CustomExtension::default(); - opt1.foo = "first".to_string(); + let opt1 = CustomExtension { + foo: "first".to_string(), + ..Default::default() + }; config.add_distributed_option_extension(opt1)?; - let mut opt2 = CustomExtension::default(); - opt2.bar = 42; + let opt2 = CustomExtension { + bar: 42, + ..Default::default() + }; config.add_distributed_option_extension(opt2)?; let flight_metadata = config.get_extension::().unwrap(); @@ -335,13 +340,17 @@ mod tests { fn test_multiple_extensions_different_prefixes() -> Result<(), Box> { let mut config = SessionConfig::new(); - let mut custom_opt = CustomExtension::default(); - custom_opt.foo = "custom_value".to_string(); - custom_opt.bar = 123; + let custom_opt = CustomExtension { + foo: "custom_value".to_string(), + bar: 123, + ..Default::default() + }; - let mut another_opt = AnotherExtension::default(); - another_opt.setting1 = "other".to_string(); - another_opt.setting2 = 456; + let another_opt = AnotherExtension { + setting1: "other".to_string(), + setting2: 456, + ..Default::default() + }; config.add_distributed_option_extension(custom_opt)?; config.add_distributed_option_extension(another_opt)?; @@ -371,8 +380,8 @@ mod tests { ); let mut new_config = SessionConfig::new(); - new_config.retrieve_distributed_option_extension::(&metadata)?; - new_config.retrieve_distributed_option_extension::(&metadata)?; + new_config.retrieve_distributed_option_extension::(metadata)?; + new_config.retrieve_distributed_option_extension::(metadata)?; let propagated_custom = get_ext::(&new_config); let propagated_another = get_ext::(&new_config);