diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 1f8085b..03f9a4d 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -142,25 +142,25 @@ impl DistributedSessionBuilder for RunOpt { &self, ctx: DistributedSessionBuilderContext, ) -> Result { - let mut builder = SessionStateBuilder::new().with_default_features(); - + let rt_builder = self.common.runtime_env_builder()?; let config = self .common .config()? - .with_collect_statistics(!self.disable_statistics) + .with_target_partitions(self.partitions()) + .with_collect_statistics(!self.disable_statistics); + let mut builder = SessionStateBuilder::new() + .with_runtime_env(rt_builder.build_arc()?) + .with_default_features() + .with_config(config) .with_distributed_user_codec(InMemoryCacheExecCodec) - .with_distributed_channel_resolver(LocalHostChannelResolver::new(self.workers.clone())) - .with_distributed_option_extension_from_headers::(&ctx.headers)? - .with_target_partitions(self.partitions()); - - let rt_builder = self.common.runtime_env_builder()?; + .with_distributed_execution(LocalHostChannelResolver::new(self.workers.clone())) + .with_distributed_option_extension_from_headers::(&ctx.headers)?; if self.mem_table { builder = builder.with_physical_optimizer_rule(Arc::new(InMemoryDataSourceRule)); } if !self.workers.is_empty() { builder = builder - .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) .with_distributed_network_coalesce_tasks( self.coalesce_tasks.unwrap_or(self.workers.len()), ) @@ -169,10 +169,7 @@ impl DistributedSessionBuilder for RunOpt { ); } - Ok(builder - .with_config(config) - .with_runtime_env(rt_builder.build_arc()?) - .build()) + Ok(builder.build()) } } @@ -196,7 +193,12 @@ impl RunOpt { } async fn run_local(mut self) -> Result<()> { - let state = self.build_session_state(Default::default()).await?; + let mut state = self.build_session_state(Default::default()).await?; + if self.mem_table { + state = SessionStateBuilder::from(state) + .with_distributed_option_extension(WarmingUpMarker::warming_up())? + .build(); + } let ctx = SessionContext::new_with_state(state); self.register_tables(&ctx).await?; @@ -218,9 +220,6 @@ impl RunOpt { for query_id in query_range.clone() { // put the WarmingUpMarker in the context, otherwise, queries will fail as the // InMemoryCacheExec node will think they should already be warmed up. - let ctx = ctx - .clone() - .with_distributed_option_extension(WarmingUpMarker::warming_up())?; for query in get_query_sql(query_id)? { self.execute_query(&ctx, &query).await?; } diff --git a/examples/in_memory_cluster.rs b/examples/in_memory_cluster.rs index 7171a75..4e83f81 100644 --- a/examples/in_memory_cluster.rs +++ b/examples/in_memory_cluster.rs @@ -7,12 +7,11 @@ use datafusion::physical_plan::displayable; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_distributed::{ ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt, - DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, create_flight_client, + DistributedSessionBuilderContext, create_flight_client, }; use futures::TryStreamExt; use hyper_util::rt::TokioIo; use std::error::Error; -use std::sync::Arc; use structopt::StructOpt; use tonic::transport::{Endpoint, Server}; @@ -41,8 +40,7 @@ async fn main() -> Result<(), Box> { let state = SessionStateBuilder::new() .with_default_features() - .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_execution(InMemoryChannelResolver::new()) .with_distributed_network_coalesce_tasks(args.network_shuffle_tasks) .with_distributed_network_shuffle_tasks(args.network_coalesce_tasks) .build(); @@ -107,7 +105,7 @@ impl InMemoryChannelResolver { async move { let builder = SessionStateBuilder::new() .with_default_features() - .with_distributed_channel_resolver(this) + .with_distributed_execution(this) .with_runtime_env(ctx.runtime_env.clone()); Ok(builder.build()) } diff --git a/examples/localhost_run.rs b/examples/localhost_run.rs index 0035a0c..48509c3 100644 --- a/examples/localhost_run.rs +++ b/examples/localhost_run.rs @@ -6,12 +6,9 @@ use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::displayable; use datafusion::prelude::{ParquetReadOptions, SessionContext}; -use datafusion_distributed::{ - BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, -}; +use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt}; use futures::TryStreamExt; use std::error::Error; -use std::sync::Arc; use structopt::StructOpt; use tonic::transport::Channel; use url::Url; @@ -47,8 +44,7 @@ async fn main() -> Result<(), Box> { let state = SessionStateBuilder::new() .with_default_features() - .with_distributed_channel_resolver(localhost_resolver) - .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_execution(localhost_resolver) .with_distributed_network_coalesce_tasks(args.network_coalesce_tasks) .with_distributed_network_shuffle_tasks(args.network_shuffle_tasks) .build(); diff --git a/examples/localhost_worker.rs b/examples/localhost_worker.rs index f02afbc..fada931 100644 --- a/examples/localhost_worker.rs +++ b/examples/localhost_worker.rs @@ -38,7 +38,7 @@ async fn main() -> Result<(), Box> { async move { Ok(SessionStateBuilder::new() .with_runtime_env(ctx.runtime_env) - .with_distributed_channel_resolver(local_host_resolver) + .with_distributed_execution(local_host_resolver) .with_default_features() .build()) } diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index dfc1ab1..63669d5 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -1,4 +1,3 @@ -use crate::ChannelResolver; use crate::channel_resolver_ext::set_distributed_channel_resolver; use crate::config_extension_ext::{ set_distributed_option_extension, set_distributed_option_extension_from_headers, @@ -7,10 +6,10 @@ use crate::distributed_planner::{ set_distributed_network_coalesce_tasks, set_distributed_network_shuffle_tasks, }; use crate::protobuf::{set_distributed_user_codec, set_distributed_user_codec_arc}; +use crate::{ChannelResolver, DistributedConfig, DistributedPhysicalOptimizerRule}; use datafusion::common::DataFusionError; use datafusion::config::ConfigExtension; -use datafusion::execution::{SessionState, SessionStateBuilder}; -use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion::execution::SessionStateBuilder; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use delegate::delegate; use http::HeaderMap; @@ -51,8 +50,9 @@ pub trait DistributedExt: Sized { /// let mut my_custom_extension = CustomExtension::default(); /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow /// // Flight request, it will be sent through gRPC metadata. - /// let mut config = SessionConfig::new() - /// .with_distributed_option_extension(my_custom_extension).unwrap(); + /// let state = SessionStateBuilder::new() + /// .with_distributed_option_extension(my_custom_extension).unwrap() + /// .build(); /// /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { /// // This function can be provided to an ArrowFlightEndpoint in order to tell it how to @@ -106,8 +106,9 @@ pub trait DistributedExt: Sized { /// let mut my_custom_extension = CustomExtension::default(); /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow /// // Flight request, it will be sent through gRPC metadata. - /// let mut config = SessionConfig::new() - /// .with_distributed_option_extension(my_custom_extension).unwrap(); + /// let state = SessionStateBuilder::new() + /// .with_distributed_option_extension(my_custom_extension).unwrap() + /// .build(); /// /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { /// // This function can be provided to an ArrowFlightEndpoint in order to tell it how to @@ -156,7 +157,9 @@ pub trait DistributedExt: Sized { /// } /// } /// - /// let config = SessionConfig::new().with_distributed_user_codec(CustomExecCodec); + /// let state = SessionStateBuilder::new() + /// .with_distributed_user_codec(CustomExecCodec) + /// .build(); /// /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { /// // This function can be provided to an ArrowFlightEndpoint in order to tell it how to @@ -177,8 +180,14 @@ pub trait DistributedExt: Sized { /// Same as [DistributedExt::set_distributed_user_codec] but with a dynamic argument. fn set_distributed_user_codec_arc(&mut self, codec: Arc); - /// Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker - /// nodes. When running in distributed mode, setting a [ChannelResolver] is required. + /// Enables distributed execution. For this, several things happen: + /// + /// - Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker + /// nodes. When running in distributed mode, setting a [ChannelResolver] is required. + /// - Injects a [DistributedPhysicalOptimizerRule] rule that will inject network boundaries + /// in the plan and will break it down into stages. + /// - Injects a [DistributedConfig] object with configuration about the amount of tasks that + /// should be spawned while distributing the queries. /// /// Example: /// @@ -204,23 +213,25 @@ pub trait DistributedExt: Sized { /// } /// } /// - /// let config = SessionConfig::new().with_distributed_channel_resolver(CustomChannelResolver); + /// let state = SessionStateBuilder::new() + /// .with_distributed_execution(CustomChannelResolver) + /// .build(); /// /// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result { /// // This function can be provided to an ArrowFlightEndpoint so that it knows how to /// // resolve tonic channels from URLs upon making network calls to other nodes. /// Ok(SessionStateBuilder::new() - /// .with_distributed_channel_resolver(CustomChannelResolver) + /// .with_distributed_execution(CustomChannelResolver) /// .build()) /// } /// ``` - fn with_distributed_channel_resolver( + fn with_distributed_execution( self, resolver: T, ) -> Self; - /// Same as [DistributedExt::with_distributed_channel_resolver] but with an in-place mutation. - fn set_distributed_channel_resolver( + /// Same as [DistributedExt::with_distributed_execution] but with an in-place mutation. + fn set_distributed_execution( &mut self, resolver: T, ); @@ -252,42 +263,53 @@ pub trait DistributedExt: Sized { fn set_distributed_network_shuffle_tasks(&mut self, tasks: usize); } -impl DistributedExt for SessionConfig { +impl DistributedExt for SessionStateBuilder { fn set_distributed_option_extension( &mut self, t: T, ) -> Result<(), DataFusionError> { - set_distributed_option_extension(self, t) + set_distributed_option_extension(self.config().get_or_insert_default(), t) } fn set_distributed_option_extension_from_headers( &mut self, headers: &HeaderMap, ) -> Result<(), DataFusionError> { - set_distributed_option_extension_from_headers::(self, headers) + set_distributed_option_extension_from_headers::( + self.config().get_or_insert_default(), + headers, + ) } fn set_distributed_user_codec(&mut self, codec: T) { - set_distributed_user_codec(self, codec) + set_distributed_user_codec(self.config().get_or_insert_default(), codec) } fn set_distributed_user_codec_arc(&mut self, codec: Arc) { - set_distributed_user_codec_arc(self, codec) + set_distributed_user_codec_arc(self.config().get_or_insert_default(), codec) } - fn set_distributed_channel_resolver( + fn set_distributed_execution( &mut self, resolver: T, ) { - set_distributed_channel_resolver(self, resolver) + let cfg = self.config().get_or_insert_default(); + set_distributed_channel_resolver(cfg, resolver); + let opts = cfg.options_mut(); + if opts.extensions.get::().is_none() { + opts.extensions.insert(DistributedConfig::default()); + } + + let rules = self.physical_optimizer_rules().get_or_insert_default(); + rules.push(Arc::new(DistributedPhysicalOptimizerRule)); } fn set_distributed_network_coalesce_tasks(&mut self, tasks: usize) { - set_distributed_network_coalesce_tasks(self, tasks) + set_distributed_network_coalesce_tasks(self.config().get_or_insert_default(), tasks) } fn set_distributed_network_shuffle_tasks(&mut self, tasks: usize) { - set_distributed_network_shuffle_tasks(self, tasks) + set_distributed_network_shuffle_tasks(self.config().get_or_insert_default(), tasks) } delegate! { @@ -308,9 +330,9 @@ impl DistributedExt for SessionConfig { #[expr($;self)] fn with_distributed_user_codec_arc(mut self, codec: Arc) -> Self; - #[call(set_distributed_channel_resolver)] + #[call(set_distributed_execution)] #[expr($;self)] - fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; + fn with_distributed_execution(mut self, resolver: T) -> Self; #[call(set_distributed_network_coalesce_tasks)] #[expr($;self)] @@ -322,126 +344,3 @@ impl DistributedExt for SessionConfig { } } } - -impl DistributedExt for SessionStateBuilder { - delegate! { - to self.config().get_or_insert_default() { - fn set_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - #[call(set_distributed_option_extension)] - #[expr($?;Ok(self))] - fn with_distributed_option_extension(mut self, t: T) -> Result; - - fn set_distributed_option_extension_from_headers(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; - #[call(set_distributed_option_extension_from_headers)] - #[expr($?;Ok(self))] - fn with_distributed_option_extension_from_headers(mut self, headers: &HeaderMap) -> Result; - - fn set_distributed_user_codec(&mut self, codec: T); - #[call(set_distributed_user_codec)] - #[expr($;self)] - fn with_distributed_user_codec(mut self, codec: T) -> Self; - - fn set_distributed_user_codec_arc(&mut self, codec: Arc); - #[call(set_distributed_user_codec_arc)] - #[expr($;self)] - fn with_distributed_user_codec_arc(mut self, codec: Arc) -> Self; - - fn set_distributed_channel_resolver(&mut self, resolver: T); - #[call(set_distributed_channel_resolver)] - #[expr($;self)] - fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; - - fn set_distributed_network_coalesce_tasks(&mut self, tasks: usize); - #[call(set_distributed_network_coalesce_tasks)] - #[expr($;self)] - fn with_distributed_network_coalesce_tasks(mut self, tasks: usize) -> Self; - - fn set_distributed_network_shuffle_tasks(&mut self, tasks: usize); - #[call(set_distributed_network_shuffle_tasks)] - #[expr($;self)] - fn with_distributed_network_shuffle_tasks(mut self, tasks: usize) -> Self; - } - } -} - -impl DistributedExt for SessionState { - delegate! { - to self.config_mut() { - fn set_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - #[call(set_distributed_option_extension)] - #[expr($?;Ok(self))] - fn with_distributed_option_extension(mut self, t: T) -> Result; - - fn set_distributed_option_extension_from_headers(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; - #[call(set_distributed_option_extension_from_headers)] - #[expr($?;Ok(self))] - fn with_distributed_option_extension_from_headers(mut self, headers: &HeaderMap) -> Result; - - fn set_distributed_user_codec(&mut self, codec: T); - #[call(set_distributed_user_codec)] - #[expr($;self)] - fn with_distributed_user_codec(mut self, codec: T) -> Self; - - fn set_distributed_user_codec_arc(&mut self, codec: Arc); - #[call(set_distributed_user_codec_arc)] - #[expr($;self)] - fn with_distributed_user_codec_arc(mut self, codec: Arc) -> Self; - - fn set_distributed_channel_resolver(&mut self, resolver: T); - #[call(set_distributed_channel_resolver)] - #[expr($;self)] - fn with_distributed_channel_resolver(mut self, resolver: T) -> Self; - - fn set_distributed_network_coalesce_tasks(&mut self, tasks: usize); - #[call(set_distributed_network_coalesce_tasks)] - #[expr($;self)] - fn with_distributed_network_coalesce_tasks(mut self, tasks: usize) -> Self; - - fn set_distributed_network_shuffle_tasks(&mut self, tasks: usize); - #[call(set_distributed_network_shuffle_tasks)] - #[expr($;self)] - fn with_distributed_network_shuffle_tasks(mut self, tasks: usize) -> Self; - } - } -} - -impl DistributedExt for SessionContext { - delegate! { - to self.state_ref().write().config_mut() { - fn set_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; - #[call(set_distributed_option_extension)] - #[expr($?;Ok(self))] - fn with_distributed_option_extension(self, t: T) -> Result; - - fn set_distributed_option_extension_from_headers(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>; - #[call(set_distributed_option_extension_from_headers)] - #[expr($?;Ok(self))] - fn with_distributed_option_extension_from_headers(self, headers: &HeaderMap) -> Result; - - fn set_distributed_user_codec(&mut self, codec: T); - #[call(set_distributed_user_codec)] - #[expr($;self)] - fn with_distributed_user_codec(self, codec: T) -> Self; - - fn set_distributed_user_codec_arc(&mut self, codec: Arc); - #[call(set_distributed_user_codec_arc)] - #[expr($;self)] - fn with_distributed_user_codec_arc(self, codec: Arc) -> Self; - - fn set_distributed_channel_resolver(&mut self, resolver: T); - #[call(set_distributed_channel_resolver)] - #[expr($;self)] - fn with_distributed_channel_resolver(self, resolver: T) -> Self; - - fn set_distributed_network_coalesce_tasks(&mut self, tasks: usize); - #[call(set_distributed_network_coalesce_tasks)] - #[expr($;self)] - fn with_distributed_network_coalesce_tasks(self, tasks: usize) -> Self; - - fn set_distributed_network_shuffle_tasks(&mut self, tasks: usize); - #[call(set_distributed_network_shuffle_tasks)] - #[expr($;self)] - fn with_distributed_network_shuffle_tasks(self, tasks: usize) -> Self; - } - } -} diff --git a/src/distributed_planner/distributed_physical_optimizer_rule.rs b/src/distributed_planner/distributed_physical_optimizer_rule.rs index 1f4de76..eba0318 100644 --- a/src/distributed_planner/distributed_physical_optimizer_rule.rs +++ b/src/distributed_planner/distributed_physical_optimizer_rule.rs @@ -66,6 +66,9 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { let Some(cfg) = config.extensions.get::() else { return Ok(plan); }; + if cfg.network_coalesce_tasks.is_none() && cfg.network_shuffle_tasks.is_none() { + return Ok(plan); + } // We can only optimize plans that are not already distributed distribute_plan(apply_network_boundaries(plan, cfg)?) } @@ -180,6 +183,9 @@ pub fn apply_network_boundaries( pub fn distribute_plan( plan: Arc, ) -> Result, DataFusionError> { + if plan.as_any().is::() { + return Ok(plan); + } let stage = match _distribute_plan_inner(Uuid::new_v4(), plan.clone(), &mut 1, 0, 1) { Ok(stage) => stage, Err(err) => { @@ -304,13 +310,13 @@ impl Referenced<'_, T> { #[cfg(test)] mod tests { + use crate::DistributedExt; + use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver; use crate::test_utils::parquet::register_parquet_tables; - use crate::{DistributedConfig, DistributedPhysicalOptimizerRule}; use crate::{assert_snapshot, display_plan_ascii}; use datafusion::error::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::{SessionConfig, SessionContext}; - use std::sync::Arc; /* shema for the "weather" table MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL] @@ -580,13 +586,12 @@ mod tests { async fn sql_to_explain(query: &str) -> Result { let config = SessionConfig::new() .with_target_partitions(4) - .with_option_extension(DistributedConfig::default()) .with_information_schema(true); let state = SessionStateBuilder::new() .with_default_features() .with_config(config) - .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_execution(InMemoryChannelResolver::new()) .build(); let ctx = SessionContext::new_with_state(state); diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs index ea956da..e24a041 100644 --- a/src/metrics/task_metrics_collector.rs +++ b/src/metrics/task_metrics_collector.rs @@ -128,7 +128,6 @@ mod tests { use futures::StreamExt; use crate::DistributedExt; - use crate::DistributedPhysicalOptimizerRule; use crate::execution_plans::DistributedExec; use crate::metrics::proto::metrics_set_proto_to_df; use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver; @@ -152,8 +151,7 @@ mod tests { let state = SessionStateBuilder::new() .with_default_features() .with_config(config) - .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_execution(InMemoryChannelResolver::new()) .with_distributed_network_coalesce_tasks(2) .with_distributed_network_shuffle_tasks(2) .build(); diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index e170fe3..3ff8a6e 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -218,7 +218,6 @@ mod tests { use uuid::Uuid; use crate::DistributedExt; - use crate::DistributedPhysicalOptimizerRule; use crate::metrics::task_metrics_rewriter::MetricsWrapperExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::metrics::MetricsSet; @@ -245,8 +244,7 @@ mod tests { if distributed { builder = builder - .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .with_distributed_execution(InMemoryChannelResolver::new()) .with_distributed_network_coalesce_tasks(2) .with_distributed_network_shuffle_tasks(2) } diff --git a/src/test_utils/in_memory_channel_resolver.rs b/src/test_utils/in_memory_channel_resolver.rs index 47fcd15..24ebdfb 100644 --- a/src/test_utils/in_memory_channel_resolver.rs +++ b/src/test_utils/in_memory_channel_resolver.rs @@ -49,7 +49,7 @@ impl InMemoryChannelResolver { async move { let builder = SessionStateBuilder::new() .with_default_features() - .with_distributed_channel_resolver(this) + .with_distributed_execution(this) .with_runtime_env(ctx.runtime_env.clone()); Ok(builder.build()) } diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index 02d6cd1..1162195 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -53,9 +53,7 @@ where let channel_resolver = LocalHostChannelResolver::new(ports.clone()); let session_builder = session_builder.map(move |builder: SessionStateBuilder| { let channel_resolver = channel_resolver.clone(); - Ok(builder - .with_distributed_channel_resolver(channel_resolver) - .build()) + Ok(builder.with_distributed_execution(channel_resolver).build()) }); let mut join_set = JoinSet::new(); for listener in listeners { diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index ed69050..6eb7adc 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -40,11 +40,14 @@ mod tests { } let (mut ctx, _guard) = start_localhost_context(3, build_state).await; - ctx.set_distributed_option_extension(CustomExtension { - foo: "foo".to_string(), - bar: 1, - baz: true, - })?; + ctx = SessionStateBuilder::from(ctx.state()) + .with_distributed_option_extension(CustomExtension { + foo: "foo".to_string(), + bar: 1, + baz: true, + })? + .build() + .into(); let mut plan: Arc = Arc::new(CustomConfigExtensionRequiredExec::new()); diff --git a/tests/tpch_validation_test.rs b/tests/tpch_validation_test.rs index 9167dac..d136c40 100644 --- a/tests/tpch_validation_test.rs +++ b/tests/tpch_validation_test.rs @@ -7,13 +7,12 @@ mod tests { use datafusion_distributed::test_utils::localhost::start_localhost_context; use datafusion_distributed::test_utils::tpch; use datafusion_distributed::{ - DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, - assert_snapshot, display_plan_ascii, explain_analyze, + DistributedExt, DistributedSessionBuilderContext, assert_snapshot, display_plan_ascii, + explain_analyze, }; use futures::TryStreamExt; use std::error::Error; use std::fs; - use std::sync::Arc; use tokio::sync::OnceCell; const PARTITIONS: usize = 6; @@ -2200,7 +2199,6 @@ mod tests { Ok(SessionStateBuilder::new() .with_runtime_env(ctx.runtime_env) .with_default_features() - .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) .with_distributed_network_coalesce_tasks(COALESCE_TASKS) .with_distributed_network_shuffle_tasks(SHUFFLE_TASKS) .build())