diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 507e408..f3cd2bb 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -36,14 +36,16 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::SessionStateBuilder; +use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; use datafusion_distributed::test_utils::localhost::start_localhost_context; -use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; +use datafusion_distributed::{ + DistributedPhysicalOptimizerRule, DistributedSessionBuilder, DistributedSessionBuilderContext, +}; use log::info; use structopt::StructOpt; @@ -110,11 +112,13 @@ pub struct RunOpt { } #[async_trait] -impl SessionBuilder for RunOpt { - fn session_state_builder( +impl DistributedSessionBuilder for RunOpt { + async fn build_session_state( &self, - mut builder: SessionStateBuilder, - ) -> Result { + _ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut builder = SessionStateBuilder::new().with_default_features(); + let mut config = self .common .config()? @@ -145,17 +149,14 @@ impl SessionBuilder for RunOpt { builder = builder.with_physical_optimizer_rule(Arc::new(rule)); } - Ok(builder + let state = builder .with_config(config) - .with_runtime_env(rt_builder.build_arc()?)) - } + .with_runtime_env(rt_builder.build_arc()?) + .build(); - async fn session_context( - &self, - ctx: SessionContext, - ) -> std::result::Result { + let ctx = SessionContext::from(state); self.register_tables(&ctx).await?; - Ok(ctx) + Ok(ctx.state()) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 572b996..6e77e1a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,2 +1,3 @@ +#[allow(unused)] pub mod ttl_map; pub mod util; diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index ff66c73..568b1f5 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -93,7 +93,7 @@ where shard.insert(key); } BucketOp::Clear => { - let keys_to_delete = std::mem::replace(&mut shard, HashSet::new()); + let keys_to_delete = std::mem::take(&mut shard); for key in keys_to_delete { data.remove(&key); } @@ -252,14 +252,14 @@ where /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The /// function terminates if `shutdown` is signalled. - async fn run_gc_loop(time: Arc, period: Duration, buckets: &Vec>) { + async fn run_gc_loop(time: Arc, period: Duration, buckets: &[Bucket]) { loop { tokio::time::sleep(period).await; Self::gc(time.clone(), buckets); } } - fn gc(time: Arc, buckets: &Vec>) { + fn gc(time: Arc, buckets: &[Bucket]) { let index = time.load(std::sync::atomic::Ordering::SeqCst) % buckets.len() as u64; buckets[index as usize].clear(); time.fetch_add(1, std::sync::atomic::Ordering::SeqCst); diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs index 41f5141..bab103a 100644 --- a/src/config_extension_ext.rs +++ b/src/config_extension_ext.rs @@ -27,9 +27,9 @@ pub trait ConfigExtensionExt { /// # use async_trait::async_trait; /// # use datafusion::common::{extensions_options, DataFusionError}; /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::SessionState; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// extensions_options! { /// pub struct CustomExtension { @@ -52,11 +52,13 @@ pub trait ConfigExtensionExt { /// struct MyCustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for MyCustomSessionBuilder { - /// async fn session_state(&self, mut state: SessionState) -> Result { + /// impl DistributedSessionBuilder for MyCustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { + /// let mut state = SessionStateBuilder::new().build(); + /// /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.retrieve_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::(&ctx.headers)?; /// Ok(state) /// } /// } @@ -76,9 +78,9 @@ pub trait ConfigExtensionExt { /// # use async_trait::async_trait; /// # use datafusion::common::{extensions_options, DataFusionError}; /// # use datafusion::config::ConfigExtension; - /// # use datafusion::execution::SessionState; + /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; - /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// extensions_options! { /// pub struct CustomExtension { @@ -101,17 +103,19 @@ pub trait ConfigExtensionExt { /// struct MyCustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for MyCustomSessionBuilder { - /// async fn session_state(&self, mut state: SessionState) -> Result { + /// impl DistributedSessionBuilder for MyCustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { + /// let mut state = SessionStateBuilder::new().build(); /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will /// // know how to deserialize the CustomExtension from the gRPC metadata. - /// state.retrieve_distributed_option_extension::()?; + /// state.retrieve_distributed_option_extension::(&ctx.headers)?; /// Ok(state) /// } /// } /// ``` fn retrieve_distributed_option_extension( &mut self, + headers: &HeaderMap, ) -> Result<(), DataFusionError>; } @@ -153,20 +157,17 @@ impl ConfigExtensionExt for SessionConfig { fn retrieve_distributed_option_extension( &mut self, + headers: &HeaderMap, ) -> Result<(), DataFusionError> { - let Some(flight_metadata) = self.get_extension::() else { - return Ok(()); - }; - let mut result = T::default(); let mut found_some = false; - for (k, v) in flight_metadata.0.iter() { + for (k, v) in headers.iter() { let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); let prefix = format!("{}.", T::PREFIX); if key.starts_with(&prefix) { found_some = true; result.set( - &key.trim_start_matches(&prefix), + key.trim_start_matches(&prefix), v.to_str().map_err(|err| { internal_datafusion_err!("Cannot parse header value: {err}") })?, @@ -185,7 +186,7 @@ impl ConfigExtensionExt for SessionStateBuilder { delegate! { to self.config().get_or_insert_default() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -194,7 +195,7 @@ impl ConfigExtensionExt for SessionState { delegate! { to self.config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -203,7 +204,7 @@ impl ConfigExtensionExt for SessionContext { delegate! { to self.state_ref().write().config_mut() { fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; } } } @@ -212,17 +213,6 @@ impl ConfigExtensionExt for SessionContext { pub(crate) struct ContextGrpcMetadata(pub HeaderMap); impl ContextGrpcMetadata { - pub(crate) fn from_headers(metadata: HeaderMap) -> Self { - let mut new = HeaderMap::new(); - for (k, v) in metadata.into_iter() { - let Some(k) = k else { continue }; - if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) { - new.insert(k, v); - } - } - Self(new) - } - fn merge(mut self, other: Self) -> Self { for (k, v) in other.0.into_iter() { let Some(k) = k else { continue }; @@ -246,16 +236,16 @@ mod tests { fn test_propagation() -> Result<(), Box> { let mut config = SessionConfig::new(); - let mut opt = CustomExtension::default(); - opt.foo = "foo".to_string(); - opt.bar = 1; - opt.baz = true; + let opt = CustomExtension { + foo: "".to_string(), + bar: 0, + baz: false, + }; config.add_distributed_option_extension(opt)?; - + let metadata = config.get_extension::().unwrap(); let mut new_config = SessionConfig::new(); - new_config.set_extension(config.get_extension::().unwrap()); - new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::(&metadata.0)?; let opt = get_ext::(&config); let new_opt = get_ext::(&new_config); @@ -294,12 +284,16 @@ mod tests { fn test_new_extension_overwrites_previous() -> Result<(), Box> { let mut config = SessionConfig::new(); - let mut opt1 = CustomExtension::default(); - opt1.foo = "first".to_string(); + let opt1 = CustomExtension { + foo: "first".to_string(), + ..Default::default() + }; config.add_distributed_option_extension(opt1)?; - let mut opt2 = CustomExtension::default(); - opt2.bar = 42; + let opt2 = CustomExtension { + bar: 42, + ..Default::default() + }; config.add_distributed_option_extension(opt2)?; let flight_metadata = config.get_extension::().unwrap(); @@ -317,7 +311,7 @@ mod tests { fn test_propagate_no_metadata() -> Result<(), Box> { let mut config = SessionConfig::new(); - config.retrieve_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::(&Default::default())?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -330,13 +324,11 @@ mod tests { let mut config = SessionConfig::new(); let mut header_map = HeaderMap::new(); header_map.insert( - HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(), + HeaderName::from_str("x-datafusion-distributed-config-other.setting").unwrap(), HeaderValue::from_str("value").unwrap(), ); - let flight_metadata = ContextGrpcMetadata::from_headers(header_map); - config.set_extension(std::sync::Arc::new(flight_metadata)); - config.retrieve_distributed_option_extension::()?; + config.retrieve_distributed_option_extension::(&header_map)?; let extension = config.options().extensions.get::(); assert!(extension.is_none()); @@ -348,13 +340,17 @@ mod tests { fn test_multiple_extensions_different_prefixes() -> Result<(), Box> { let mut config = SessionConfig::new(); - let mut custom_opt = CustomExtension::default(); - custom_opt.foo = "custom_value".to_string(); - custom_opt.bar = 123; + let custom_opt = CustomExtension { + foo: "custom_value".to_string(), + bar: 123, + ..Default::default() + }; - let mut another_opt = AnotherExtension::default(); - another_opt.setting1 = "other".to_string(); - another_opt.setting2 = 456; + let another_opt = AnotherExtension { + setting1: "other".to_string(), + setting2: 456, + ..Default::default() + }; config.add_distributed_option_extension(custom_opt)?; config.add_distributed_option_extension(another_opt)?; @@ -384,9 +380,8 @@ mod tests { ); let mut new_config = SessionConfig::new(); - new_config.set_extension(flight_metadata); - new_config.retrieve_distributed_option_extension::()?; - new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::(metadata)?; + new_config.retrieve_distributed_option_extension::(metadata)?; let propagated_custom = get_ext::(&new_config); let propagated_another = get_ext::(&new_config); diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 2c14bba..0be1c5b 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -3,6 +3,7 @@ use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; +use crate::flight_service::session_builder::DistributedSessionBuilderContext; use crate::plan::{DistributedCodec, PartitionGroup}; use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto}; use crate::user_provided_codec::get_user_codec; @@ -10,9 +11,7 @@ use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; -use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::optimizer::OptimizerConfig; -use datafusion::prelude::SessionConfig; +use datafusion::execution::SessionState; use futures::TryStreamExt; use prost::Message; use std::sync::Arc; @@ -90,7 +89,7 @@ impl ArrowFlightEndpoint { async fn get_state_and_stage( &self, doget: DoGet, - metadata: MetadataMap, + metadata_map: MetadataMap, ) -> Result<(SessionState, Arc), Status> { let key = doget .stage_key @@ -106,54 +105,34 @@ impl ArrowFlightEndpoint { .stage_proto .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; - let mut config = SessionConfig::default(); - config.set_extension(Arc::new(ContextGrpcMetadata::from_headers( - metadata.into_headers(), - ))); - - let state_builder = SessionStateBuilder::new() - .with_runtime_env(Arc::clone(&self.runtime)) - .with_config(config) - .with_default_features(); - - let state_builder = self - .session_builder - .session_state_builder(state_builder) - .map_err(|err| datafusion_error_to_tonic_status(&err))?; - - let state = state_builder.build(); + let headers = metadata_map.into_headers(); let mut state = self .session_builder - .session_state(state) + .build_session_state(DistributedSessionBuilderContext { + runtime_env: Arc::clone(&self.runtime), + headers: headers.clone(), + }) .await .map_err(|err| datafusion_error_to_tonic_status(&err))?; - let function_registry = - state.function_registry().ok_or(Status::invalid_argument( - "FunctionRegistry not present in newly built SessionState", - ))?; - let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec); if let Some(ref user_codec) = get_user_codec(state.config()) { combined_codec.push_arc(Arc::clone(user_codec)); } - let stage = stage_from_proto( - stage_proto, - function_registry, - self.runtime.as_ref(), - &combined_codec, - ) - .map(Arc::new) - .map_err(|err| { - Status::invalid_argument(format!("Cannot decode stage proto: {err}")) - })?; + let stage = + stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &combined_codec) + .map(Arc::new) + .map_err(|err| { + Status::invalid_argument(format!("Cannot decode stage proto: {err}")) + })?; // Add the extensions that might be required for ExecutionPlan nodes in the plan let config = state.config_mut(); config.set_extension(Arc::clone(&self.channel_manager)); config.set_extension(stage.clone()); + config.set_extension(Arc::new(ContextGrpcMetadata(headers))); Ok::<_, Status>((state, stage)) }) diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index e3b53c8..b389074 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -5,4 +5,6 @@ mod session_builder; pub(crate) use do_get::DoGet; pub use service::{ArrowFlightEndpoint, StageKey}; -pub use session_builder::{NoopSessionBuilder, SessionBuilder}; +pub use session_builder::{ + DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext, +}; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index fc19269..78e6bdb 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -1,6 +1,6 @@ use crate::channel_manager::ChannelManager; -use crate::flight_service::session_builder::NoopSessionBuilder; -use crate::flight_service::SessionBuilder; +use crate::flight_service::session_builder::DefaultSessionBuilder; +use crate::flight_service::DistributedSessionBuilder; use crate::stage::ExecutionStage; use crate::ChannelResolver; use arrow_flight::flight_service_server::FlightService; @@ -36,7 +36,7 @@ pub struct ArrowFlightEndpoint { pub(super) runtime: Arc, #[allow(clippy::type_complexity)] pub(super) stages: DashMap)>>>, - pub(super) session_builder: Arc, + pub(super) session_builder: Arc, } impl ArrowFlightEndpoint { @@ -45,13 +45,13 @@ impl ArrowFlightEndpoint { channel_manager: Arc::new(ChannelManager::new(channel_resolver)), runtime: Arc::new(RuntimeEnv::default()), stages: DashMap::new(), - session_builder: Arc::new(NoopSessionBuilder), + session_builder: Arc::new(DefaultSessionBuilder), } } pub fn with_session_builder( &mut self, - session_builder: impl SessionBuilder + Send + Sync + 'static, + session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, ) { self.session_builder = Arc::new(session_builder); } diff --git a/src/flight_service/session_builder.rs b/src/flight_service/session_builder.rs index 64be9dd..3decda0 100644 --- a/src/flight_service/session_builder.rs +++ b/src/flight_service/session_builder.rs @@ -1,27 +1,34 @@ use async_trait::async_trait; use datafusion::error::DataFusionError; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::prelude::SessionContext; +use http::HeaderMap; +use std::sync::Arc; + +#[derive(Debug, Clone, Default)] +pub struct DistributedSessionBuilderContext { + pub runtime_env: Arc, + pub headers: HeaderMap, +} /// Trait called by the Arrow Flight endpoint that handles distributed parts of a DataFusion -/// plan for building a DataFusion's [datafusion::prelude::SessionContext]. +/// plan for building a DataFusion's [SessionState]. #[async_trait] -pub trait SessionBuilder { - /// Takes a [SessionStateBuilder] and adds whatever is necessary for it to work, like - /// custom extension codecs, custom physical optimization rules, UDFs, UDAFs, config - /// extensions, etc... +pub trait DistributedSessionBuilder { + /// Builds a custom [SessionState] scoped to a single ArrowFlight gRPC call, allowing the + /// users to provide a customized DataFusion session with things like custom extension codecs, + /// custom physical optimization rules, UDFs, UDAFs, config extensions, etc... /// - /// Example: adding some custom extension plan codecs + /// Example: /// /// ```rust /// # use std::sync::Arc; /// # use async_trait::async_trait; /// # use datafusion::error::DataFusionError; - /// # use datafusion::execution::runtime_env::RuntimeEnv; - /// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; + /// # use datafusion::execution::{FunctionRegistry, SessionState, SessionStateBuilder}; /// # use datafusion::physical_plan::ExecutionPlan; /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec; - /// # use datafusion_distributed::{with_user_codec, SessionBuilder}; + /// # use datafusion_distributed::{with_user_codec, DistributedSessionBuilder, DistributedSessionBuilderContext}; /// /// #[derive(Debug)] /// struct CustomExecCodec; @@ -40,79 +47,54 @@ pub trait SessionBuilder { /// struct CustomSessionBuilder; /// /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// fn session_state_builder(&self, mut builder: SessionStateBuilder) -> Result { - /// // Add your UDFs, optimization rules, etc... - /// Ok(with_user_codec(builder, CustomExecCodec)) - /// } - /// } - /// ``` - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(builder) - } - - /// Modifies the [SessionState] and returns it. Same as [SessionBuilder::session_state_builder] - /// but operating on an already built [SessionState]. - /// - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::DataFusionError; - /// # use datafusion::execution::SessionState; - /// # use datafusion_distributed::SessionBuilder; + /// impl DistributedSessionBuilder for CustomSessionBuilder { + /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result { + /// let builder = SessionStateBuilder::new() + /// .with_runtime_env(ctx.runtime_env.clone()) + /// .with_default_features(); /// - /// #[derive(Clone)] - /// struct CustomSessionBuilder; + /// let builder = with_user_codec(builder, CustomExecCodec); /// - /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// async fn session_state(&self, state: SessionState) -> Result { - /// // mutate the state adding any custom logic - /// Ok(state) - /// } - /// } - /// ``` - async fn session_state(&self, state: SessionState) -> Result { - Ok(state) - } - - /// Modifies the [SessionContext] and returns it. Same as [SessionBuilder::session_state_builder] - /// or [SessionBuilder::session_state] but operation on an already built [SessionContext]. - /// - /// Example: - /// - /// ```rust - /// # use async_trait::async_trait; - /// # use datafusion::common::DataFusionError; - /// # use datafusion::prelude::SessionContext; - /// # use datafusion_distributed::SessionBuilder; - /// - /// #[derive(Clone)] - /// struct CustomSessionBuilder; + /// // Add your UDFs, optimization rules, etc... /// - /// #[async_trait] - /// impl SessionBuilder for CustomSessionBuilder { - /// async fn session_context(&self, ctx: SessionContext) -> Result { - /// // mutate the context adding any custom logic - /// Ok(ctx) + /// Ok(builder.build()) /// } /// } /// ``` - async fn session_context( + async fn build_session_state( &self, - ctx: SessionContext, - ) -> Result { - Ok(ctx) - } + ctx: DistributedSessionBuilderContext, + ) -> Result; } -/// Noop implementation of the [SessionBuilder]. Used by default if no [SessionBuilder] is provided +/// Noop implementation of the [DistributedSessionBuilder]. Used by default if no [DistributedSessionBuilder] is provided /// while building the Arrow Flight endpoint. #[derive(Debug, Clone)] -pub struct NoopSessionBuilder; +pub struct DefaultSessionBuilder; + +#[async_trait] +impl DistributedSessionBuilder for DefaultSessionBuilder { + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result { + Ok(SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env.clone()) + .with_default_features() + .build()) + } +} -impl SessionBuilder for NoopSessionBuilder {} +#[async_trait] +impl DistributedSessionBuilder for F +where + F: Fn(DistributedSessionBuilderContext) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, +{ + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result { + self(ctx).await + } +} diff --git a/src/lib.rs b/src/lib.rs index 40515a9..dbf8a5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,10 @@ pub mod test_utils; pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; pub use config_extension_ext::ConfigExtensionExt; -pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder}; +pub use flight_service::{ + ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder, + DistributedSessionBuilderContext, +}; pub use physical_optimizer::DistributedPhysicalOptimizerRule; pub use plan::ArrowFlightReadExec; pub use stage::{display_stage_graphviz, ExecutionStage}; diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index 83a328e..a782807 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -1,12 +1,13 @@ use crate::{ - ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, SessionBuilder, + ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelManager, ChannelResolver, + DistributedSessionBuilder, DistributedSessionBuilderContext, }; use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; +use datafusion::common::runtime::JoinSet; use datafusion::common::DataFusionError; -use datafusion::execution::SessionStateBuilder; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::prelude::SessionContext; -use datafusion::{common::runtime::JoinSet, prelude::SessionConfig}; use std::error::Error; use std::sync::Arc; use std::time::Duration; @@ -19,7 +20,7 @@ pub async fn start_localhost_context( session_builder: B, ) -> (SessionContext, JoinSet<()>) where - B: SessionBuilder + Send + Sync + 'static, + B: DistributedSessionBuilder + Send + Sync + 'static, B: Clone, { let listeners = futures::future::try_join_all( @@ -53,25 +54,19 @@ where } tokio::time::sleep(Duration::from_millis(100)).await; - let config = SessionConfig::new().with_target_partitions(3); - - let builder = SessionStateBuilder::new() - .with_default_features() - .with_config(config); - let builder = session_builder.session_state_builder(builder).unwrap(); - - let state = builder.build(); - let state = session_builder.session_state(state).await.unwrap(); - - let ctx = SessionContext::new_with_state(state); - let ctx = session_builder.session_context(ctx).await.unwrap(); - - ctx.state_ref() - .write() + let mut state = session_builder + .build_session_state(DistributedSessionBuilderContext { + runtime_env: Arc::new(RuntimeEnv::default()), + headers: Default::default(), + }) + .await + .unwrap(); + state .config_mut() .set_extension(Arc::new(ChannelManager::new(channel_resolver))); + state.config_mut().options_mut().execution.target_partitions = 3; - (ctx, join_set) + (SessionContext::from(state), join_set) } #[derive(Clone)] @@ -108,7 +103,7 @@ impl ChannelResolver for LocalHostChannelResolver { pub async fn spawn_flight_service( channel_resolver: impl ChannelResolver + Send + Sync + 'static, - session_builder: impl SessionBuilder + Send + Sync + 'static, + session_builder: impl DistributedSessionBuilder + Send + Sync + 'static, incoming: TcpListener, ) -> Result<(), Box> { let mut endpoint = ArrowFlightEndpoint::new(channel_resolver); diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index d671d90..af84a92 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -1,12 +1,11 @@ #[cfg(all(feature = "integration", test))] mod tests { - use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::{extensions_options, internal_err}; use datafusion::config::ConfigExtension; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionState, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -15,10 +14,10 @@ mod tests { execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{add_user_codec, ConfigExtensionExt}; use datafusion_distributed::{ - ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, ConfigExtensionExt, DistributedSessionBuilderContext, }; + use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::TryStreamExt; use prost::Message; @@ -28,23 +27,19 @@ mod tests { #[tokio::test] async fn custom_config_extension() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - - #[async_trait] - impl SessionBuilder for CustomSessionBuilder { - async fn session_state( - &self, - mut state: SessionState, - ) -> Result { - state.retrieve_distributed_option_extension::()?; - add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); - Ok(state) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + state.retrieve_distributed_option_extension::(&ctx.headers)?; + add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); + Ok(state) } - let (mut ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; - add_user_codec(&mut ctx, CustomConfigExtensionRequiredExecCodec); + let (mut ctx, _guard) = start_localhost_context(3, build_state).await; ctx.add_distributed_option_extension(CustomExtension { foo: "foo".to_string(), bar: 1, diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index 3d2870d..fa86082 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -7,7 +7,7 @@ mod tests { use datafusion::arrow::util::pretty::pretty_format_batches; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::logical_expr::Operator; use datafusion::physical_expr::expressions::{col, lit, BinaryExpr}; @@ -22,11 +22,11 @@ mod tests { use datafusion::physical_plan::{ displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; - use datafusion_distributed::assert_snapshot; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, assert_snapshot, DistributedSessionBuilderContext, }; + use datafusion_distributed::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::{stream, TryStreamExt}; @@ -38,18 +38,18 @@ mod tests { #[tokio::test] #[ignore] async fn custom_extension_codec() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - impl SessionBuilder for CustomSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(with_user_codec(builder, Int64ListExecCodec)) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + add_user_codec(state.config_mut(), Int64ListExecCodec); + Ok(state) } - let (ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, build_state).await; let single_node_plan = build_plan(false)?; assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r" diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs index 1a572f2..273b332 100644 --- a/tests/distributed_aggregation.rs +++ b/tests/distributed_aggregation.rs @@ -5,13 +5,13 @@ mod tests { use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; use datafusion_distributed::test_utils::plan::distribute_aggregate; - use datafusion_distributed::{assert_snapshot, NoopSessionBuilder}; + use datafusion_distributed::{assert_snapshot, DefaultSessionBuilder}; use futures::TryStreamExt; use std::error::Error; #[tokio::test] async fn distributed_aggregation() -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(3, NoopSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, DefaultSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index e8d2ca5..b030d09 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -3,7 +3,7 @@ mod tests { use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext, + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -13,7 +13,8 @@ mod tests { }; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::{ - with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + add_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, + DistributedSessionBuilderContext, }; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; @@ -26,17 +27,18 @@ mod tests { #[tokio::test] async fn test_error_propagation() -> Result<(), Box> { - #[derive(Clone)] - struct CustomSessionBuilder; - impl SessionBuilder for CustomSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - Ok(with_user_codec(builder, ErrorExecCodec)) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut state = SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .build(); + add_user_codec(state.config_mut(), ErrorExecCodec); + Ok(state) } - let (ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + + let (ctx, _guard) = start_localhost_context(3, build_state).await; let mut plan: Arc = Arc::new(ErrorExec::new("something failed")); diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index d6c2cfa..c2b8160 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -4,7 +4,7 @@ mod tests { use datafusion::physical_plan::{displayable, execute_stream}; use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::parquet::register_parquet_tables; - use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, NoopSessionBuilder}; + use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, DefaultSessionBuilder}; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -12,7 +12,7 @@ mod tests { #[tokio::test] #[ignore] async fn highly_distributed_query() -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(9, NoopSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(9, DefaultSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx.sql(r#"SELECT * FROM flights_1m"#).await?; diff --git a/tests/tpch_validation_test.rs b/tests/tpch_validation_test.rs index c54dc41..6c4bb0f 100644 --- a/tests/tpch_validation_test.rs +++ b/tests/tpch_validation_test.rs @@ -3,13 +3,13 @@ mod common; #[cfg(all(feature = "integration", test))] mod tests { use crate::common::{ensure_tpch_data, get_test_data_dir, get_test_tpch_query}; - use async_trait::async_trait; use datafusion::error::DataFusionError; - use datafusion::execution::SessionStateBuilder; - + use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; + use datafusion_distributed::{ + DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, + }; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; @@ -125,47 +125,37 @@ mod tests { } async fn test_tpch_query(query_id: u8) -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context(2, TestSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(2, build_state).await; run_tpch_query(ctx, query_id).await } - #[derive(Clone)] - struct TestSessionBuilder; - - #[async_trait] - impl SessionBuilder for TestSessionBuilder { - fn session_state_builder( - &self, - builder: SessionStateBuilder, - ) -> Result { - let mut config = SessionConfig::new().with_target_partitions(3); - - // FIXME: these three options are critical for the correct function of the library - // but we are not enforcing that the user sets them. They are here at the moment - // but we should figure out a way to do this better. - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold = 0; - config - .options_mut() - .optimizer - .hash_join_single_partition_threshold_rows = 0; - - config.options_mut().optimizer.prefer_hash_join = true; - // end critical options section - let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); - Ok(builder - .with_config(config) - .with_physical_optimizer_rule(Arc::new(rule))) - } - - async fn session_context( - &self, - ctx: SessionContext, - ) -> std::result::Result { - Ok(ctx) - } + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + let mut config = SessionConfig::new().with_target_partitions(3); + + // FIXME: these three options are critical for the correct function of the library + // but we are not enforcing that the user sets them. They are here at the moment + // but we should figure out a way to do this better. + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold = 0; + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold_rows = 0; + + config.options_mut().optimizer.prefer_hash_join = true; + // end critical options section + + let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); + Ok(SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .with_physical_optimizer_rule(Arc::new(rule)) + .build()) } // test_non_distributed_consistency runs each TPC-H query twice - once in a distributed manner