diff --git a/src/channel_manager.rs b/src/channel_manager.rs deleted file mode 100644 index 08c60e6..0000000 --- a/src/channel_manager.rs +++ /dev/null @@ -1,71 +0,0 @@ -use async_trait::async_trait; -use datafusion::common::internal_datafusion_err; -use datafusion::error::DataFusionError; -use datafusion::execution::TaskContext; -use datafusion::prelude::{SessionConfig, SessionContext}; -use delegate::delegate; -use std::sync::Arc; -use tonic::body::BoxBody; -use url::Url; - -#[derive(Clone)] -pub struct ChannelManager(Arc); - -impl ChannelManager { - pub fn new(resolver: impl ChannelResolver + Send + Sync + 'static) -> Self { - Self(Arc::new(resolver)) - } -} - -pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService< - http::Request, - http::Response, - tonic::transport::Error, ->; - -/// Abstracts networking details so that users can implement their own network resolution -/// mechanism. -#[async_trait] -pub trait ChannelResolver { - /// Gets all available worker URLs. Used during stage assignment. - fn get_urls(&self) -> Result, DataFusionError>; - /// For a given URL, get a channel for communicating to it. - async fn get_channel_for_url(&self, url: &Url) -> Result; -} - -impl ChannelManager { - delegate! { - to self.0 { - pub fn get_urls(&self) -> Result, DataFusionError>; - pub async fn get_channel_for_url(&self, url: &Url) -> Result; - } - } -} - -impl TryInto for &SessionConfig { - type Error = DataFusionError; - - fn try_into(self) -> Result { - Ok(self - .get_extension::() - .ok_or_else(|| internal_datafusion_err!("No extension ChannelManager"))? - .as_ref() - .clone()) - } -} - -impl TryInto for &TaskContext { - type Error = DataFusionError; - - fn try_into(self) -> Result { - self.session_config().try_into() - } -} - -impl TryInto for &SessionContext { - type Error = DataFusionError; - - fn try_into(self) -> Result { - self.task_ctx().as_ref().try_into() - } -} diff --git a/src/channel_manager_ext.rs b/src/channel_manager_ext.rs new file mode 100644 index 0000000..1a6df43 --- /dev/null +++ b/src/channel_manager_ext.rs @@ -0,0 +1,52 @@ +use async_trait::async_trait; +use datafusion::error::DataFusionError; +use datafusion::prelude::SessionConfig; +use std::sync::Arc; +use tonic::body::BoxBody; +use url::Url; + +pub(crate) fn set_distributed_channel_resolver( + cfg: &mut SessionConfig, + channel_resolver: impl ChannelResolver + Send + Sync + 'static, +) { + cfg.set_extension(Arc::new(ChannelResolverExtension(Arc::new( + channel_resolver, + )))); +} + +pub(crate) fn get_distributed_channel_resolver( + cfg: &SessionConfig, +) -> Option> { + cfg.get_extension::() + .map(|cm| cm.0.clone()) +} + +#[derive(Clone)] +struct ChannelResolverExtension(Arc); + +pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService< + http::Request, + http::Response, + tonic::transport::Error, +>; + +/// Abstracts networking details so that users can implement their own network resolution +/// mechanism. +#[async_trait] +pub trait ChannelResolver { + /// Gets all available worker URLs. Used during stage assignment. + fn get_urls(&self) -> Result, DataFusionError>; + /// For a given URL, get a channel for communicating to it. + async fn get_channel_for_url(&self, url: &Url) -> Result; +} + +#[async_trait] +impl ChannelResolver for Arc { + fn get_urls(&self) -> Result, DataFusionError> { + self.as_ref().get_urls() + } + + async fn get_channel_for_url(&self, url: &Url) -> Result { + self.as_ref().get_channel_for_url(url).await + } +} diff --git a/src/composed_extension_codec.rs b/src/common/composed_extension_codec.rs similarity index 100% rename from src/composed_extension_codec.rs rename to src/common/composed_extension_codec.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 6e77e1a..3104cc9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,3 +1,6 @@ +mod composed_extension_codec; #[allow(unused)] pub mod ttl_map; pub mod util; + +pub(crate) use composed_extension_codec::ComposedPhysicalExtensionCodec; diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index bab103a..66e9eb4 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -1,8 +1,6 @@ use datafusion::common::{internal_datafusion_err, DataFusionError}; use datafusion::config::ConfigExtension; -use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::prelude::{SessionConfig, SessionContext}; -use delegate::delegate; +use datafusion::prelude::SessionConfig; use http::{HeaderMap, HeaderName}; use std::error::Error; use std::str::FromStr; @@ -10,203 +8,64 @@ use std::sync::Arc; const FLIGHT_METADATA_CONFIG_PREFIX: &str = "x-datafusion-distributed-config-"; -/// Extension trait for `SessionConfig` to add support for propagating [ConfigExtension]s across -/// network calls. -pub trait ConfigExtensionExt { - /// Adds the provided [ConfigExtension] to the distributed context. The [ConfigExtension] will - /// be serialized using gRPC metadata and sent across tasks. Users are expected to call this - /// method with their own extensions to be able to access them in any place in the - /// plan. - /// - /// This method also adds the provided [ConfigExtension] to the current session option - /// extensions, the same as calling [SessionConfig::with_option_extension]. - /// - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::{extensions_options, DataFusionError}; - /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::{SessionState, SessionStateBuilder}; - /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; - /// - /// extensions_options! { - /// pub struct CustomExtension { - /// pub foo: String, default = "".to_string() - /// pub bar: usize, default = 0 - /// pub baz: bool, default = false - /// } - /// } - /// - /// impl ConfigExtension for CustomExtension { - /// const PREFIX: &'static str = "custom"; - /// } - /// - /// let mut config = SessionConfig::new(); - /// let mut opt = CustomExtension::default(); - /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow - /// // Flight request, it will be sent through gRPC metadata. - /// config.add_distributed_option_extension(opt).unwrap(); - /// - /// struct MyCustomSessionBuilder; - /// - /// #[async_trait] - /// impl DistributedSessionBuilder for MyCustomSessionBuilder { - /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { - /// let mut state = SessionStateBuilder::new().build(); - /// - /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will - /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.retrieve_distributed_option_extension::(&ctx.headers)?; - /// Ok(state) - /// } - /// } - /// ``` - fn add_distributed_option_extension( - &mut self, - t: T, - ) -> Result<(), DataFusionError>; - - /// Gets the specified [ConfigExtension] from the distributed context and adds it to - /// the [SessionConfig::options] extensions. The function will build a new [ConfigExtension] - /// out of the Arrow Flight gRPC metadata present in the [SessionConfig] and will propagate it - /// to the extension options. - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::{extensions_options, DataFusionError}; - /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::{SessionState, SessionStateBuilder}; - /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; - /// - /// extensions_options! { - /// pub struct CustomExtension { - /// pub foo: String, default = "".to_string() - /// pub bar: usize, default = 0 - /// pub baz: bool, default = false - /// } - /// } - /// - /// impl ConfigExtension for CustomExtension { - /// const PREFIX: &'static str = "custom"; - /// } - /// - /// let mut config = SessionConfig::new(); - /// let mut opt = CustomExtension::default(); - /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow - /// // Flight request, it will be sent through gRPC metadata. - /// config.add_distributed_option_extension(opt).unwrap(); - /// - /// struct MyCustomSessionBuilder; - /// - /// #[async_trait] - /// impl DistributedSessionBuilder for MyCustomSessionBuilder { - /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { - /// let mut state = SessionStateBuilder::new().build(); - /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will - /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.retrieve_distributed_option_extension::(&ctx.headers)?; - /// Ok(state) - /// } - /// } - /// ``` - fn retrieve_distributed_option_extension( - &mut self, - headers: &HeaderMap, - ) -> Result<(), DataFusionError>; -} - -impl ConfigExtensionExt for SessionConfig { - fn add_distributed_option_extension( - &mut self, - t: T, - ) -> Result<(), DataFusionError> { - fn parse_err(err: impl Error) -> DataFusionError { - DataFusionError::Internal(format!("Failed to add config extension: {err}")) - } - let mut meta = HeaderMap::new(); - - for entry in t.entries() { - if let Some(value) = entry.value { - meta.insert( - HeaderName::from_str(&format!( - "{}{}.{}", - FLIGHT_METADATA_CONFIG_PREFIX, - T::PREFIX, - entry.key - )) - .map_err(parse_err)?, - value.parse().map_err(parse_err)?, - ); - } - } - let flight_metadata = ContextGrpcMetadata(meta); - match self.get_extension::() { - None => self.set_extension(Arc::new(flight_metadata)), - Some(prev) => { - let prev = prev.as_ref().clone(); - self.set_extension(Arc::new(prev.merge(flight_metadata))) - } - } - self.options_mut().extensions.insert(t); - Ok(()) +pub(crate) fn set_distributed_option_extension( + cfg: &mut SessionConfig, + t: T, +) -> Result<(), DataFusionError> { + fn parse_err(err: impl Error) -> DataFusionError { + DataFusionError::Internal(format!("Failed to add config extension: {err}")) } - - fn retrieve_distributed_option_extension( - &mut self, - headers: &HeaderMap, - ) -> Result<(), DataFusionError> { - let mut result = T::default(); - let mut found_some = false; - for (k, v) in headers.iter() { - let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); - let prefix = format!("{}.", T::PREFIX); - if key.starts_with(&prefix) { - found_some = true; - result.set( - key.trim_start_matches(&prefix), - v.to_str().map_err(|err| { - internal_datafusion_err!("Cannot parse header value: {err}") - })?, - )?; - } + let mut meta = HeaderMap::new(); + + for entry in t.entries() { + if let Some(value) = entry.value { + meta.insert( + HeaderName::from_str(&format!( + "{}{}.{}", + FLIGHT_METADATA_CONFIG_PREFIX, + T::PREFIX, + entry.key + )) + .map_err(parse_err)?, + value.parse().map_err(parse_err)?, + ); } - if !found_some { - return Ok(()); - } - self.options_mut().extensions.insert(result); - Ok(()) } -} - -impl ConfigExtensionExt for SessionStateBuilder { - delegate! { - to self.config().get_or_insert_default() { - fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; + let flight_metadata = ContextGrpcMetadata(meta); + match cfg.get_extension::() { + None => cfg.set_extension(Arc::new(flight_metadata)), + Some(prev) => { + let prev = prev.as_ref().clone(); + cfg.set_extension(Arc::new(prev.merge(flight_metadata))) } } + cfg.options_mut().extensions.insert(t); + Ok(()) } -impl ConfigExtensionExt for SessionState { - delegate! { - to self.config_mut() { - fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; +pub(crate) fn set_distributed_option_extension_from_headers( + cfg: &mut SessionConfig, + headers: &HeaderMap, +) -> Result<(), DataFusionError> { + let mut result = T::default(); + let mut found_some = false; + for (k, v) in headers.iter() { + let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); + let prefix = format!("{}.", T::PREFIX); + if key.starts_with(&prefix) { + found_some = true; + result.set( + key.trim_start_matches(&prefix), + v.to_str() + .map_err(|err| internal_datafusion_err!("Cannot parse header value: {err}"))?, + )?; } } -} - -impl ConfigExtensionExt for SessionContext { - delegate! { - to self.state_ref().write().config_mut() { - fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; - } + if !found_some { + return Ok(()); } + cfg.options_mut().extensions.insert(result); + Ok(()) } #[derive(Clone, Debug, Default)] @@ -224,8 +83,10 @@ impl ContextGrpcMetadata { #[cfg(test)] mod tests { - use crate::config_extension_ext::ContextGrpcMetadata; - use crate::ConfigExtensionExt; + use crate::config_extension_ext::{ + set_distributed_option_extension, set_distributed_option_extension_from_headers, + ContextGrpcMetadata, + }; use datafusion::common::extensions_options; use datafusion::config::ConfigExtension; use datafusion::prelude::SessionConfig; @@ -242,10 +103,13 @@ mod tests { baz: false, }; - config.add_distributed_option_extension(opt)?; + set_distributed_option_extension(&mut config, opt)?; let metadata = config.get_extension::().unwrap(); let mut new_config = SessionConfig::new(); - new_config.retrieve_distributed_option_extension::(&metadata.0)?; + set_distributed_option_extension_from_headers::( + &mut new_config, + &metadata.0, + )?; let opt = get_ext::(&config); let new_opt = get_ext::(&new_config); @@ -262,7 +126,7 @@ mod tests { let mut config = SessionConfig::new(); let opt = CustomExtension::default(); - config.add_distributed_option_extension(opt)?; + set_distributed_option_extension(&mut config, opt)?; let flight_metadata = config.get_extension::(); assert!(flight_metadata.is_some()); @@ -288,13 +152,13 @@ mod tests { foo: "first".to_string(), ..Default::default() }; - config.add_distributed_option_extension(opt1)?; + set_distributed_option_extension(&mut config, opt1)?; let opt2 = CustomExtension { bar: 42, ..Default::default() }; - config.add_distributed_option_extension(opt2)?; + set_distributed_option_extension(&mut config, opt2)?; let flight_metadata = config.get_extension::().unwrap(); let metadata = &flight_metadata.0; @@ -311,7 +175,10 @@ mod tests { fn test_propagate_no_metadata() -> Result<(), Box> { let mut config = SessionConfig::new(); - config.retrieve_distributed_option_extension::(&Default::default())?; + set_distributed_option_extension_from_headers::( + &mut config, + &Default::default(), + )?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -328,7 +195,7 @@ mod tests { HeaderValue::from_str("value").unwrap(), ); - config.retrieve_distributed_option_extension::(&header_map)?; + set_distributed_option_extension_from_headers::(&mut config, &header_map)?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -352,8 +219,8 @@ mod tests { ..Default::default() }; - config.add_distributed_option_extension(custom_opt)?; - config.add_distributed_option_extension(another_opt)?; + set_distributed_option_extension(&mut config, custom_opt)?; + set_distributed_option_extension(&mut config, another_opt)?; let flight_metadata = config.get_extension::().unwrap(); let metadata = &flight_metadata.0; @@ -380,8 +247,14 @@ mod tests { ); let mut new_config = SessionConfig::new(); - new_config.retrieve_distributed_option_extension::(metadata)?; - new_config.retrieve_distributed_option_extension::(metadata)?; + set_distributed_option_extension_from_headers::( + &mut new_config, + metadata, + )?; + set_distributed_option_extension_from_headers::( + &mut new_config, + metadata, + )?; let propagated_custom = get_ext::(&new_config); let propagated_another = get_ext::(&new_config); @@ -399,7 +272,7 @@ mod tests { let mut config = SessionConfig::new(); let extension = InvalidExtension::default(); - let result = config.add_distributed_option_extension(extension); + let result = set_distributed_option_extension(&mut config, extension); assert!(result.is_err()); } @@ -408,7 +281,7 @@ mod tests { let mut config = SessionConfig::new(); let extension = InvalidValueExtension::default(); - let result = config.add_distributed_option_extension(extension); + let result = set_distributed_option_extension(&mut config, extension); assert!(result.is_err()); } diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs new file mode 100644 index 0000000..f7f06e0 --- /dev/null +++ b/src/distributed_ext.rs @@ -0,0 +1,340 @@ +use crate::channel_manager_ext::set_distributed_channel_resolver; +use crate::config_extension_ext::{ + set_distributed_option_extension, set_distributed_option_extension_from_headers, +}; +use crate::user_codec_ext::set_distributed_user_codec; +use crate::ChannelResolver; +use datafusion::common::DataFusionError; +use datafusion::config::ConfigExtension; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use delegate::delegate; +use http::HeaderMap; + +/// Extends DataFusion with distributed capabilities. +pub trait DistributedExt: Sized { + /// Adds the provided [ConfigExtension] to the distributed context. The [ConfigExtension] will + /// be serialized using gRPC metadata and sent across tasks. Users are expected to call this + /// method with their own extensions to be able to access them in any place in the + /// plan. + /// + /// This method also adds the provided [ConfigExtension] to the current session option + /// extensions, the same as calling [SessionConfig::with_option_extension]. + /// + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{DistributedExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut my_custom_extension = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// let mut config = SessionConfig::new() + /// .with_distributed_option_extension(my_custom_extension).unwrap(); + /// + /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { + /// // This function can be provided to an ArrowFlightEndpoint in order to tell it how to + /// // build sessions that retrieve the CustomExtension from gRPC metadata. + /// Ok(SessionStateBuilder::new() + /// .with_distributed_option_extension_from_headers::(&ctx.headers)? + /// .build()) + /// } + /// ``` + fn with_distributed_option_extension( + self, + t: T, + ) -> Result; + + /// Same as [DistributedExt::with_distributed_option_extension] but with an in-place mutation + fn set_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError>; + + /// Adds the provided [ConfigExtension] to the distributed context. The [ConfigExtension] will + /// be serialized using gRPC metadata and sent across tasks. Users are expected to call this + /// method with their own extensions to be able to access them in any place in the + /// plan. + /// + /// This method also adds the provided [ConfigExtension] to the current session option + /// extensions, the same as calling [SessionConfig::with_option_extension]. + /// + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{DistributedExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut my_custom_extension = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// let mut config = SessionConfig::new() + /// .with_distributed_option_extension(my_custom_extension).unwrap(); + /// + /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { + /// // This function can be provided to an ArrowFlightEndpoint in order to tell it how to + /// // build sessions that retrieve the CustomExtension from gRPC metadata. + /// Ok(SessionStateBuilder::new() + /// .with_distributed_option_extension_from_headers::(&ctx.headers)? + /// .build()) + /// } + /// ``` + fn with_distributed_option_extension_from_headers( + self, + headers: &HeaderMap, + ) -> Result; + + /// Same as [DistributedExt::with_distributed_option_extension_from_headers] but with an in-place mutation + fn set_distributed_option_extension_from_headers( + &mut self, + headers: &HeaderMap, + ) -> Result<(), DataFusionError>; + + /// Injects a user-defined [PhysicalExtensionCodec] that is capable of encoding/decoding + /// custom execution nodes. + /// + /// Example: + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion::common::DataFusionError; + /// # use datafusion::execution::{SessionState, FunctionRegistry, SessionStateBuilder}; + /// # use datafusion::physical_plan::ExecutionPlan; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec; + /// # use datafusion_distributed::{DistributedExt, DistributedSessionBuilderContext}; + /// + /// #[derive(Debug)] + /// struct CustomExecCodec; + /// + /// impl PhysicalExtensionCodec for CustomExecCodec { + /// fn try_decode(&self, buf: &[u8], inputs: &[Arc], registry: &dyn FunctionRegistry) -> datafusion::common::Result> { + /// todo!() + /// } + /// + /// fn try_encode(&self, node: Arc, buf: &mut Vec) -> datafusion::common::Result<()> { + /// todo!() + /// } + /// } + /// + /// let config = SessionConfig::new().with_distributed_user_codec(CustomExecCodec); + /// + /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { + /// // This function can be provided to an ArrowFlightEndpoint in order to tell it how to + /// // encode/decode CustomExec nodes. + /// Ok(SessionStateBuilder::new() + /// .with_distributed_user_codec(CustomExecCodec) + /// .build()) + /// } + /// ``` + fn with_distributed_user_codec(self, codec: T) -> Self; + + /// Same as [DistributedExt::with_distributed_user_codec] but with an in-place mutation + fn set_distributed_user_codec(&mut self, codec: T); + + /// Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker + /// nodes. When running in distributed mode, setting a [ChannelResolver] is required. + /// + /// Example: + /// + /// ``` + /// # use async_trait::async_trait; + /// # use datafusion::common::DataFusionError; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; + /// # use datafusion::prelude::SessionConfig; + /// # use url::Url; + /// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedSessionBuilderContext}; + /// + /// struct CustomChannelResolver; + /// + /// #[async_trait] + /// impl ChannelResolver for CustomChannelResolver { + /// fn get_urls(&self) -> Result, DataFusionError> { + /// todo!() + /// } + /// + /// async fn get_channel_for_url(&self, url: &Url) -> Result { + /// todo!() + /// } + /// } + /// + /// let config = SessionConfig::new().with_distributed_channel_resolver(CustomChannelResolver); + /// + /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { + /// // This function can be provided to an ArrowFlightEndpoint so that it knows how to + /// // resolve tonic channels from URLs upon making network calls to other nodes. + /// Ok(SessionStateBuilder::new() + /// .with_distributed_channel_resolver(CustomChannelResolver) + /// .build()) + /// } + /// ``` + fn with_distributed_channel_resolver( + self, + resolver: T, + ) -> Self; + + /// Same as [DistributedExt::with_distributed_channel_resolver] but with an in-place mutation. + fn set_distributed_channel_resolver( + &mut self, + resolver: T, + ); +} + +impl DistributedExt for SessionConfig { + fn set_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError> { + set_distributed_option_extension(self, t) + } + + fn set_distributed_option_extension_from_headers( + &mut self, + headers: &HeaderMap, + ) -> Result<(), DataFusionError> { + set_distributed_option_extension_from_headers::(self, headers) + } + + fn set_distributed_user_codec(&mut self, codec: T) { + set_distributed_user_codec(self, codec) + } + + fn set_distributed_channel_resolver( + &mut self, + resolver: T, + ) { + set_distributed_channel_resolver(self, resolver) + } + + delegate! { + to self { + #[call(set_distributed_option_extension)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension(mut self, t: T) -> Result; + + #[call(set_distributed_option_extension_from_headers)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension_from_headers(mut self, headers: &HeaderMap) -> Result; + + #[call(set_distributed_user_codec)] + #[expr($;self)] + fn with_distributed_user_codec(mut self, codec: T) -> Self; + + #[call(set_distributed_channel_resolver)] + #[expr($;self)] + fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; + } + } +} + +impl DistributedExt for SessionStateBuilder { + delegate! { + to self.config().get_or_insert_default() { + fn set_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + #[call(set_distributed_option_extension)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension(mut self, t: T) -> Result; + + fn set_distributed_option_extension_from_headers(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; + #[call(set_distributed_option_extension_from_headers)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension_from_headers(mut self, headers: &HeaderMap) -> Result; + + fn set_distributed_user_codec(&mut self, codec: T); + #[call(set_distributed_user_codec)] + #[expr($;self)] + fn with_distributed_user_codec(mut self, codec: T) -> Self; + + fn set_distributed_channel_resolver(&mut self, resolver: T); + #[call(set_distributed_channel_resolver)] + #[expr($;self)] + fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; + } + } +} + +impl DistributedExt for SessionState { + delegate! { + to self.config_mut() { + fn set_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + #[call(set_distributed_option_extension)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension(mut self, t: T) -> Result; + + fn set_distributed_option_extension_from_headers(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; + #[call(set_distributed_option_extension_from_headers)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension_from_headers(mut self, headers: &HeaderMap) -> Result; + + fn set_distributed_user_codec(&mut self, codec: T); + #[call(set_distributed_user_codec)] + #[expr($;self)] + fn with_distributed_user_codec(mut self, codec: T) -> Self; + + fn set_distributed_channel_resolver(&mut self, resolver: T); + #[call(set_distributed_channel_resolver)] + #[expr($;self)] + fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; + } + } +} + +impl DistributedExt for SessionContext { + delegate! { + to self.state_ref().write().config_mut() { + fn set_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + #[call(set_distributed_option_extension)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension(self, t: T) -> Result; + + fn set_distributed_option_extension_from_headers(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; + #[call(set_distributed_option_extension_from_headers)] + #[expr($?;Ok(self))] + fn with_distributed_option_extension_from_headers(self, headers: &HeaderMap) -> Result; + + fn set_distributed_user_codec(&mut self, codec: T); + #[call(set_distributed_user_codec)] + #[expr($;self)] + fn with_distributed_user_codec(self, codec: T) -> Self; + + fn set_distributed_channel_resolver(&mut self, resolver: T); + #[call(set_distributed_channel_resolver)] + #[expr($;self)] + fn with_distributed_channel_resolver(self, resolver: T) -> Self; + } + } +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 0be1c5b..adc7d74 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,12 +1,12 @@ use super::service::StageKey; -use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::common::ComposedPhysicalExtensionCodec; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; use crate::plan::{DistributedCodec, PartitionGroup}; use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto}; -use crate::user_provided_codec::get_user_codec; +use crate::user_codec_ext::get_distributed_user_codec; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; @@ -117,7 +117,7 @@ impl ArrowFlightEndpoint { let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec); - if let Some(ref user_codec) = get_user_codec(state.config()) { + if let Some(ref user_codec) = get_distributed_user_codec(state.config()) { combined_codec.push_arc(Arc::clone(user_codec)); } @@ -130,7 +130,6 @@ impl ArrowFlightEndpoint { // Add the extensions that might be required for ExecutionPlan nodes in the plan let config = state.config_mut(); - config.set_extension(Arc::clone(&self.channel_manager)); config.set_extension(stage.clone()); config.set_extension(Arc::new(ContextGrpcMetadata(headers))); diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index b389074..8277e32 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -7,4 +7,5 @@ pub(crate) use do_get::DoGet; pub use service::{ArrowFlightEndpoint, StageKey}; pub use session_builder::{ DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext, + MappedDistributedSessionBuilder, MappedDistributedSessionBuilderExt, }; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index 78e6bdb..403ec20 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -1,8 +1,5 @@ -use crate::channel_manager::ChannelManager; -use crate::flight_service::session_builder::DefaultSessionBuilder; use crate::flight_service::DistributedSessionBuilder; use crate::stage::ExecutionStage; -use crate::ChannelResolver; use arrow_flight::flight_service_server::FlightService; use arrow_flight::{ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, @@ -32,7 +29,6 @@ pub struct StageKey { } pub struct ArrowFlightEndpoint { - pub(super) channel_manager: Arc, pub(super) runtime: Arc, #[allow(clippy::type_complexity)] pub(super) stages: DashMap)>>>, @@ -40,21 +36,13 @@ pub struct ArrowFlightEndpoint { } impl ArrowFlightEndpoint { - pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self { + pub fn new(session_builder: impl DistributedSessionBuilder + Send + Sync + 'static) -> Self { Self { - channel_manager: Arc::new(ChannelManager::new(channel_resolver)), runtime: Arc::new(RuntimeEnv::default()), stages: DashMap::new(), - session_builder: Arc::new(DefaultSessionBuilder), + session_builder: Arc::new(session_builder), } } - - pub fn with_session_builder( - &mut self, - session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, - ) { - self.session_builder = Arc::new(session_builder); - } } #[async_trait] diff --git a/src/flight_service/session_builder.rs b/src/flight_service/session_builder.rs index 3decda0..3cf617b 100644 --- a/src/flight_service/session_builder.rs +++ b/src/flight_service/session_builder.rs @@ -28,7 +28,7 @@ pub trait DistributedSessionBuilder { /// # use datafusion::execution::{FunctionRegistry, SessionState, SessionStateBuilder}; /// # use datafusion::physical_plan::ExecutionPlan; /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec; - /// # use datafusion_distributed::{with_user_codec, DistributedSessionBuilder, DistributedSessionBuilderContext}; + /// # use datafusion_distributed::{DistributedExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// #[derive(Debug)] /// struct CustomExecCodec; @@ -49,12 +49,10 @@ pub trait DistributedSessionBuilder { /// #[async_trait] /// impl DistributedSessionBuilder for CustomSessionBuilder { /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { - /// let builder = SessionStateBuilder::new() + /// let mut builder = SessionStateBuilder::new() /// .with_runtime_env(ctx.runtime_env.clone()) /// .with_default_features(); - /// - /// let builder = with_user_codec(builder, CustomExecCodec); - /// + /// builder.set_distributed_user_codec(CustomExecCodec); /// // Add your UDFs, optimization rules, etc... /// /// Ok(builder.build()) @@ -85,6 +83,7 @@ impl DistributedSessionBuilder for DefaultSessionBuilder { } } +/// Implementation of [DistributedSessionBuilder] for any async function that returns a [Result] #[async_trait] impl DistributedSessionBuilder for F where @@ -98,3 +97,70 @@ where self(ctx).await } } + +pub trait MappedDistributedSessionBuilderExt { + /// Maps an existing [DistributedSessionBuilder] allowing to add further extensions + /// to its already built [SessionStateBuilder]. + /// + /// Useful if there's already a [DistributedSessionBuilder] that needs to be extended + /// with further capabilities. + /// + /// Example: + /// + /// ```rust + /// # use datafusion::execution::SessionStateBuilder; + /// # use datafusion_distributed::{DefaultSessionBuilder, MappedDistributedSessionBuilderExt}; + /// + /// let session_builder = DefaultSessionBuilder + /// .map(|b: SessionStateBuilder| { + /// // Add further things. + /// Ok(b.build()) + /// }); + /// ``` + fn map(self, f: F) -> MappedDistributedSessionBuilder + where + Self: Sized, + F: Fn(SessionStateBuilder) -> Result; +} + +impl MappedDistributedSessionBuilderExt for T { + fn map(self, f: F) -> MappedDistributedSessionBuilder + where + Self: Sized, + { + MappedDistributedSessionBuilder { + inner: self, + f: Arc::new(f), + } + } +} + +pub struct MappedDistributedSessionBuilder { + inner: T, + f: Arc, +} + +impl Clone for MappedDistributedSessionBuilder { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + f: self.f.clone(), + } + } +} + +#[async_trait] +impl DistributedSessionBuilder for MappedDistributedSessionBuilder +where + T: DistributedSessionBuilder + Send + Sync + 'static, + F: Fn(SessionStateBuilder) -> Result + Send + Sync, +{ + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result { + let state = self.inner.build_session_state(ctx).await?; + let builder = SessionStateBuilder::new_from_existing(state); + (self.f)(builder) + } +} diff --git a/src/lib.rs b/src/lib.rs index dbf8a5a..1746695 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,27 +1,27 @@ #![deny(clippy::all)] -mod channel_manager; +mod channel_manager_ext; mod common; -mod composed_extension_codec; mod config_extension_ext; +mod distributed_ext; mod errors; mod flight_service; mod physical_optimizer; mod plan; mod stage; mod task; -mod user_provided_codec; +mod user_codec_ext; #[cfg(any(feature = "integration", test))] pub mod test_utils; -pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; -pub use config_extension_ext::ConfigExtensionExt; +pub use channel_manager_ext::{BoxCloneSyncChannel, ChannelResolver}; +pub use distributed_ext::DistributedExt; pub use flight_service::{ ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder, - DistributedSessionBuilderContext, + DistributedSessionBuilderContext, MappedDistributedSessionBuilder, + MappedDistributedSessionBuilderExt, }; pub use physical_optimizer::DistributedPhysicalOptimizerRule; pub use plan::ArrowFlightReadExec; pub use stage::{display_stage_graphviz, ExecutionStage}; -pub use user_provided_codec::{add_user_codec, with_user_codec}; diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index 36d7786..04cf6d9 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,12 +1,13 @@ use super::combined::CombinedRecordBatchStream; -use crate::channel_manager::ChannelManager; -use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::channel_manager_ext::get_distributed_channel_resolver; +use crate::common::ComposedPhysicalExtensionCodec; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::tonic_status_to_datafusion_error; use crate::flight_service::{DoGet, StageKey}; use crate::plan::DistributedCodec; use crate::stage::{proto_from_stage, ExecutionStage}; -use crate::user_provided_codec::get_user_codec; +use crate::user_codec_ext::get_distributed_user_codec; +use crate::ChannelResolver; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; use arrow_flight::flight_service_client::FlightServiceClient; @@ -158,7 +159,12 @@ impl ExecutionPlan for ArrowFlightReadExec { }; // get the channel manager and current stage from our context - let channel_manager: ChannelManager = context.as_ref().try_into()?; + let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config()) + else { + return exec_err!( + "ArrowFlightReadExec requires a ChannelResolver in the session config" + ); + }; let stage = context .session_config() @@ -183,7 +189,7 @@ impl ExecutionPlan for ArrowFlightReadExec { let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec {}); - if let Some(ref user_codec) = get_user_codec(context.session_config()) { + if let Some(ref user_codec) = get_distributed_user_codec(context.session_config()) { combined_codec.push_arc(Arc::clone(user_codec)); } @@ -199,8 +205,8 @@ impl ExecutionPlan for ArrowFlightReadExec { let stream = async move { let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| { - let child_stage_proto_capture = child_stage_proto.clone(); - let channel_manager_capture = channel_manager.clone(); + let child_stage_proto = child_stage_proto.clone(); + let channel_resolver = channel_resolver.clone(); let schema = schema.clone(); let query_id = query_id.clone(); let flight_metadata = flight_metadata @@ -218,7 +224,7 @@ impl ExecutionPlan for ArrowFlightReadExec { ))?; let ticket_bytes = DoGet { - stage_proto: Some(child_stage_proto_capture), + stage_proto: Some(child_stage_proto), partition: partition as u64, stage_key: Some(key), task_number: i as u64, @@ -235,7 +241,7 @@ impl ExecutionPlan for ArrowFlightReadExec { flight_metadata, &url, schema.clone(), - &channel_manager_capture, + &channel_resolver, ) .await } @@ -261,7 +267,7 @@ async fn stream_from_stage_task( metadata: ContextGrpcMetadata, url: &Url, schema: SchemaRef, - channel_manager: &ChannelManager, + channel_manager: &impl ChannelResolver, ) -> Result { let channel = channel_manager.get_channel_for_url(url).await?; diff --git a/src/stage/execution_stage.rs b/src/stage/execution_stage.rs index d66591c..c4ba9b8 100644 --- a/src/stage/execution_stage.rs +++ b/src/stage/execution_stage.rs @@ -1,19 +1,19 @@ use std::sync::Arc; -use datafusion::common::internal_err; +use datafusion::common::{exec_err, internal_err}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use crate::channel_manager_ext::get_distributed_channel_resolver; +use crate::task::ExecutionTask; +use crate::ChannelResolver; use itertools::Itertools; use rand::Rng; use url::Url; use uuid::Uuid; -use crate::task::ExecutionTask; -use crate::ChannelManager; - /// A unit of isolation for a portion of a physical execution plan /// that can be executed independently and across a network boundary. /// It implements [`ExecutionPlan`] and can be executed to produce a @@ -171,11 +171,8 @@ impl ExecutionStage { format!("Stage {:<3}{}", self.num, child_str) } - pub fn try_assign( - self, - channel_manager: impl TryInto, - ) -> Result { - let urls: Vec = channel_manager.try_into()?.get_urls()?; + pub fn try_assign(self, channel_resolver: &impl ChannelResolver) -> Result { + let urls: Vec = channel_resolver.get_urls()?; if urls.is_empty() { return internal_err!("No URLs found in ChannelManager"); } @@ -264,14 +261,12 @@ impl ExecutionPlan for ExecutionStage { .downcast_ref::() .expect("Unwrapping myself should always work"); - let channel_manager = context - .session_config() - .get_extension::() - .ok_or(DataFusionError::Execution( - "ChannelManager not found in session config".to_string(), - ))?; + let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config()) + else { + return exec_err!("ChannelManager not found in session config"); + }; - let urls = channel_manager.get_urls()?; + let urls = channel_resolver.get_urls()?; let assigned_stage = stage .try_assign_urls(&urls) diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index a782807..c20472d 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -1,12 +1,14 @@ use crate::{ - ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, + ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedSessionBuilder, DistributedSessionBuilderContext, + MappedDistributedSessionBuilderExt, }; use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; use datafusion::common::runtime::JoinSet; use datafusion::common::DataFusionError; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::SessionStateBuilder; use datafusion::prelude::SessionContext; use std::error::Error; use std::sync::Arc; @@ -42,12 +44,17 @@ where .collect(); let channel_resolver = LocalHostChannelResolver::new(ports.clone()); + let session_builder = session_builder.map(move |builder: SessionStateBuilder| { + let channel_resolver = channel_resolver.clone(); + Ok(builder + .with_distributed_channel_resolver(channel_resolver) + .build()) + }); let mut join_set = JoinSet::new(); for listener in listeners { - let channel_resolver = channel_resolver.clone(); let session_builder = session_builder.clone(); join_set.spawn(async move { - spawn_flight_service(channel_resolver, session_builder, listener) + spawn_flight_service(session_builder, listener) .await .unwrap(); }); @@ -61,9 +68,6 @@ where }) .await .unwrap(); - state - .config_mut() - .set_extension(Arc::new(ChannelManager::new(channel_resolver))); state.config_mut().options_mut().execution.target_partitions = 3; (SessionContext::from(state), join_set) @@ -102,12 +106,10 @@ impl ChannelResolver for LocalHostChannelResolver { } pub async fn spawn_flight_service( - channel_resolver: impl ChannelResolver + Send + Sync + 'static, session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, incoming: TcpListener, ) -> Result<(), Box> { - let mut endpoint = ArrowFlightEndpoint::new(channel_resolver); - endpoint.with_session_builder(session_builder); + let endpoint = ArrowFlightEndpoint::new(session_builder); let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming); diff --git a/src/user_codec_ext.rs b/src/user_codec_ext.rs new file mode 100644 index 0000000..f1e7ab9 --- /dev/null +++ b/src/user_codec_ext.rs @@ -0,0 +1,18 @@ +use datafusion::prelude::SessionConfig; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use std::sync::Arc; + +pub struct UserProvidedCodec(Arc); + +pub(crate) fn set_distributed_user_codec( + cfg: &mut SessionConfig, + codec: T, +) { + cfg.set_extension(Arc::new(UserProvidedCodec(Arc::new(codec)))) +} + +pub(crate) fn get_distributed_user_codec( + cfg: &SessionConfig, +) -> Option> { + Some(Arc::clone(&cfg.get_extension::()?.0)) +} diff --git a/src/user_provided_codec.rs b/src/user_provided_codec.rs deleted file mode 100644 index 4881165..0000000 --- a/src/user_provided_codec.rs +++ /dev/null @@ -1,126 +0,0 @@ -use datafusion::execution::SessionStateBuilder; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_proto::physical_plan::PhysicalExtensionCodec; -use std::sync::Arc; - -pub struct UserProvidedCodec(Arc); - -/// Injects a user-defined codec that is capable of encoding/decoding custom execution nodes. -/// It will inject the codec as a config extension in the provided [SessionConfig], [SessionContext] -/// or [SessionStateBuilder]. -/// -/// Example: -/// -/// ``` -/// # use std::sync::Arc; -/// # use datafusion::execution::{SessionState, FunctionRegistry, SessionStateBuilder}; -/// # use datafusion::physical_plan::ExecutionPlan; -/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec; -/// # use datafusion_distributed::{add_user_codec}; -/// -/// #[derive(Debug)] -/// struct CustomExecCodec; -/// -/// impl PhysicalExtensionCodec for CustomExecCodec { -/// fn try_decode(&self, buf: &[u8], inputs: &[Arc], registry: &dyn FunctionRegistry) -> datafusion::common::Result> { -/// todo!() -/// } -/// -/// fn try_encode(&self, node: Arc, buf: &mut Vec) -> datafusion::common::Result<()> { -/// todo!() -/// } -/// } -/// -/// let builder = SessionStateBuilder::new(); -/// let mut state = builder.build(); -/// add_user_codec(state.config_mut(), CustomExecCodec); -/// ``` -#[allow(private_bounds)] -pub fn add_user_codec( - transport: &mut impl UserCodecTransport, - codec: impl PhysicalExtensionCodec + 'static, -) { - transport.set(codec); -} - -/// Adds a user-defined codec that is capable of encoding/decoding custom execution nodes. -/// It returns the [SessionContext], [SessionConfig] or [SessionStateBuilder] passed on the first -/// argument with the user-defined codec already placed into the config extensions. -/// -/// Example: -/// -/// ``` -/// # use std::sync::Arc; -/// # use datafusion::execution::{SessionState, FunctionRegistry, SessionStateBuilder}; -/// # use datafusion::physical_plan::ExecutionPlan; -/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec; -/// # use datafusion_distributed::with_user_codec; -/// -/// #[derive(Debug)] -/// struct CustomExecCodec; -/// -/// impl PhysicalExtensionCodec for CustomExecCodec { -/// fn try_decode(&self, buf: &[u8], inputs: &[Arc], registry: &dyn FunctionRegistry) -> datafusion::common::Result> { -/// todo!() -/// } -/// -/// fn try_encode(&self, node: Arc, buf: &mut Vec) -> datafusion::common::Result<()> { -/// todo!() -/// } -/// } -/// -/// let builder = SessionStateBuilder::new(); -/// let builder = with_user_codec(builder, CustomExecCodec); -/// let state = builder.build(); -/// ``` -#[allow(private_bounds)] -pub fn with_user_codec( - mut transport: T, - codec: impl PhysicalExtensionCodec + 'static, -) -> T { - transport.set(codec); - transport -} - -#[allow(private_bounds)] -pub(crate) fn get_user_codec( - transport: &impl UserCodecTransport, -) -> Option> { - transport.get() -} - -trait UserCodecTransport { - fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static); - fn get(&self) -> Option>; -} - -impl UserCodecTransport for SessionConfig { - fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) { - self.set_extension(Arc::new(UserProvidedCodec(Arc::new(codec)))); - } - - fn get(&self) -> Option> { - Some(Arc::clone(&self.get_extension::()?.0)) - } -} - -impl UserCodecTransport for SessionContext { - fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) { - self.state_ref().write().config_mut().set(codec) - } - - fn get(&self) -> Option> { - self.state_ref().read().config().get() - } -} - -impl UserCodecTransport for SessionStateBuilder { - fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) { - self.config().get_or_insert_default().set(codec); - } - - fn get(&self) -> Option> { - // Nobody will never want to retriever a user codec from a SessionStateBuilder - None - } -} diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index af84a92..a08b615 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -14,10 +14,8 @@ mod tests { execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{ - add_user_codec, ConfigExtensionExt, DistributedSessionBuilderContext, - }; use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; + use datafusion_distributed::{DistributedExt, DistributedSessionBuilderContext}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::TryStreamExt; use prost::Message; @@ -30,17 +28,16 @@ mod tests { async fn build_state( ctx: DistributedSessionBuilderContext, ) -> Result { - let mut state = SessionStateBuilder::new() + Ok(SessionStateBuilder::new() .with_runtime_env(ctx.runtime_env) .with_default_features() - .build(); - state.retrieve_distributed_option_extension::(&ctx.headers)?; - add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); - Ok(state) + .with_distributed_option_extension_from_headers::(&ctx.headers)? + .with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec) + .build()) } let (mut ctx, _guard) = start_localhost_context(3, build_state).await; - ctx.add_distributed_option_extension(CustomExtension { + ctx.set_distributed_option_extension(CustomExtension { foo: "foo".to_string(), bar: 1, baz: true, diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index fa86082..fff1c55 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -24,7 +24,7 @@ mod tests { }; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - add_user_codec, assert_snapshot, DistributedSessionBuilderContext, + assert_snapshot, DistributedExt, DistributedSessionBuilderContext, }; use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; @@ -41,12 +41,11 @@ mod tests { async fn build_state( ctx: DistributedSessionBuilderContext, ) -> Result { - let mut state = SessionStateBuilder::new() + Ok(SessionStateBuilder::new() .with_runtime_env(ctx.runtime_env) .with_default_features() - .build(); - add_user_codec(state.config_mut(), Int64ListExecCodec); - Ok(state) + .with_distributed_user_codec(Int64ListExecCodec) + .build()) } let (ctx, _guard) = start_localhost_context(3, build_state).await; diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index b030d09..e4424fc 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -13,7 +13,7 @@ mod tests { }; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - add_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, + ArrowFlightReadExec, DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, }; use datafusion_proto::physical_plan::PhysicalExtensionCodec; @@ -30,12 +30,11 @@ mod tests { async fn build_state( ctx: DistributedSessionBuilderContext, ) -> Result { - let mut state = SessionStateBuilder::new() + Ok(SessionStateBuilder::new() .with_runtime_env(ctx.runtime_env) .with_default_features() - .build(); - add_user_codec(state.config_mut(), ErrorExecCodec); - Ok(state) + .with_distributed_user_codec(ErrorExecCodec) + .build()) } let (ctx, _guard) = start_localhost_context(3, build_state).await;