diff --git a/Cargo.lock b/Cargo.lock index 9ab5890..4b03f80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1129,6 +1129,7 @@ dependencies = [ "prost", "rand 0.8.5", "tokio", + "tokio-stream", "tonic", "tower 0.5.2", "tpchgen", diff --git a/Cargo.toml b/Cargo.toml index c7ac91e..564be59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,5 @@ [workspace] -members = [ - "benchmarks" -] +members = ["benchmarks"] [workspace.dependencies] datafusion = { version = "49.0.0" } @@ -37,14 +35,16 @@ tpchgen = { git = "https://github.com/clflushopt/tpchgen-rs", rev = "c8d82343252 tpchgen-arrow = { git = "https://github.com/clflushopt/tpchgen-rs", rev = "c8d823432528eed4f70fca5a1296a66c68a389a8", optional = true } parquet = { version = "55.2.0", optional = true } arrow = { version = "55.2.0", optional = true } +tokio-stream = { version = "0.1.17", optional = true } [features] integration = [ "insta", "tpchgen", - "tpchgen-arrow", + "tpchgen-arrow", "parquet", - "arrow" + "arrow", + "tokio-stream", ] [dev-dependencies] diff --git a/benchmarks/README.md b/benchmarks/README.md index 6723821..75c1b8b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -14,4 +14,11 @@ After generating the data with the command above: ```shell cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1 -``` \ No newline at end of file +``` + +In order to validate the correctness of the results against single node execution, add +`--validate` + +```shell +cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1 --validate +``` diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index b0ac256..882603a 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -100,8 +100,7 @@ pub struct RunOpt { #[structopt(short = "t", long = "sorted")] sorted: bool, - /// Mark the first column of each table as sorted in ascending order. - /// The tables should have been created with the `--sort` option for this to have any effect. + /// The maximum number of partitions per task. #[structopt(long = "ppt")] partitions_per_task: Option, } @@ -115,8 +114,23 @@ impl SessionBuilder for RunOpt { let mut config = self .common .config()? - .with_collect_statistics(!self.disable_statistics); + .with_collect_statistics(!self.disable_statistics) + .with_target_partitions(self.partitions()); + + // FIXME: these three options are critical for the correct function of the library + // but we are not enforcing that the user sets them. They are here at the moment + // but we should figure out a way to do this better. 
+ config + .options_mut() + .optimizer + .hash_join_single_partition_threshold = 0; + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold_rows = 0; + config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; + // end critical options section let rt_builder = self.common.runtime_env_builder()?; let mut rule = DistributedPhysicalOptimizerRule::new(); @@ -140,7 +154,7 @@ impl SessionBuilder for RunOpt { impl RunOpt { pub async fn run(self) -> Result<()> { - let (ctx, _guard) = start_localhost_context([50051], self.clone()).await; + let (ctx, _guard) = start_localhost_context(1, self.clone()).await; println!("Running benchmarks with the following options: {self:?}"); let query_range = match self.query { Some(query_id) => query_id..=query_id, @@ -180,23 +194,22 @@ impl RunOpt { let sql = &get_query_sql(query_id)?; + let single_node_ctx = SessionContext::new(); + self.register_tables(&single_node_ctx).await?; + for i in 0..self.iterations() { let start = Instant::now(); + let mut result = vec![]; // query 15 is special, with 3 statements. the second statement is the one from which we // want to capture the results - let mut result = vec![]; - if query_id == 15 { - for (n, query) in sql.iter().enumerate() { - if n == 1 { - result = self.execute_query(ctx, query).await?; - } else { - self.execute_query(ctx, query).await?; - } - } - } else { - for query in sql { + let result_stmt = if query_id == 15 { 1 } else { sql.len() - 1 }; + + for (i, query) in sql.iter().enumerate() { + if i == result_stmt { result = self.execute_query(ctx, query).await?; + } else { + self.execute_query(ctx, query).await?; } } @@ -208,6 +221,7 @@ impl RunOpt { println!( "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" ); + query_results.push(QueryResult { elapsed, row_count }); } diff --git a/context.rs b/context.rs deleted file mode 100644 index eae1e12..0000000 --- a/context.rs +++ /dev/null @@ -1,20 +0,0 @@ -use url::Url; -use uuid::Uuid; - -#[derive(Debug, Clone)] -pub struct StageContext { - /// Unique identifier of the Stage. - pub id: Uuid, - /// Number of tasks involved in the query. - pub n_tasks: usize, - /// Unique identifier of the input Stage. - pub input_id: Uuid, - /// Urls from which the current stage will need to read data. 
- pub input_urls: Vec, -} - -#[derive(Debug, Clone)] -pub struct StageTaskContext { - /// Index of the current task in a stage - pub task_idx: usize, -} diff --git a/src/errors/arrow_error.rs b/src/errors/arrow_error.rs index c810609..33565ee 100644 --- a/src/errors/arrow_error.rs +++ b/src/errors/arrow_error.rs @@ -224,8 +224,8 @@ mod tests { let (recovered_error, recovered_ctx) = proto.to_arrow_error(); if original_error.to_string() != recovered_error.to_string() { - println!("original error: {}", original_error.to_string()); - println!("recovered error: {}", recovered_error.to_string()); + println!("original error: {}", original_error); + println!("recovered error: {}", recovered_error); } assert_eq!(original_error.to_string(), recovered_error.to_string()); diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 172aa86..e2de185 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -17,8 +17,8 @@ mod schema_error; pub fn datafusion_error_to_tonic_status(err: &DataFusionError) -> tonic::Status { let err = DataFusionErrorProto::from_datafusion_error(err); let err = err.encode_to_vec(); - let status = tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into()); - status + + tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into()) } /// Decodes a [DataFusionError] from a [tonic::Status] error. If the provided [tonic::Status] diff --git a/src/errors/schema_error.rs b/src/errors/schema_error.rs index 4ec2097..e55709a 100644 --- a/src/errors/schema_error.rs +++ b/src/errors/schema_error.rs @@ -198,7 +198,7 @@ impl SchemaErrorProto { valid_fields, } => SchemaErrorProto { inner: Some(SchemaErrorInnerProto::FieldNotFound(FieldNotFoundProto { - field: Some(Box::new(ColumnProto::from_column(&field))), + field: Some(Box::new(ColumnProto::from_column(field))), valid_fields: valid_fields.iter().map(ColumnProto::from_column).collect(), })), backtrace: backtrace.cloned(), diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index a2b52ab..38b7d25 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,29 +1,39 @@ use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; -use crate::plan::DistributedCodec; -use crate::stage::{stage_from_proto, ExecutionStageProto}; +use crate::plan::{DistributedCodec, PartitionGroup}; +use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto}; use crate::user_provided_codec::get_user_codec; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; -use datafusion::execution::SessionStateBuilder; +use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::optimizer::OptimizerConfig; -use datafusion::prelude::SessionContext; use futures::TryStreamExt; use prost::Message; use std::sync::Arc; use tonic::{Request, Response, Status}; +use super::service::StageKey; + #[derive(Clone, PartialEq, ::prost::Message)] pub struct DoGet { /// The ExecutionStage that we are going to execute #[prost(message, optional, tag = "1")] pub stage_proto: Option, - /// the partition of the stage to execute + /// The index to the task within the stage that we want to execute #[prost(uint64, tag = "2")] + pub task_number: u64, + /// the partition number we want to execute + #[prost(uint64, tag = "3")] pub partition: u64, + /// The 
stage key that identifies the stage. This is useful to keep + /// outside of the stage proto as it is used to store the stage + /// and we may not need to deserialize the entire stage proto + /// if we already have stored it + #[prost(message, optional, tag = "4")] + pub stage_key: Option, } impl ArrowFlightEndpoint { @@ -36,59 +46,28 @@ impl ArrowFlightEndpoint { Status::invalid_argument(format!("Cannot decode DoGet message: {err}")) })?; - let stage_msg = doget - .stage_proto - .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; - - let state_builder = SessionStateBuilder::new() - .with_runtime_env(Arc::clone(&self.runtime)) - .with_default_features(); - let state_builder = self - .session_builder - .session_state_builder(state_builder) - .map_err(|err| datafusion_error_to_tonic_status(&err))?; - - let state = state_builder.build(); - let mut state = self - .session_builder - .session_state(state) - .await - .map_err(|err| datafusion_error_to_tonic_status(&err))?; - - let function_registry = state.function_registry().ok_or(Status::invalid_argument( - "FunctionRegistry not present in newly built SessionState", - ))?; - - let mut combined_codec = ComposedPhysicalExtensionCodec::default(); - combined_codec.push(DistributedCodec); - if let Some(ref user_codec) = get_user_codec(state.config()) { - combined_codec.push_arc(Arc::clone(&user_codec)); - } - - let stage = stage_from_proto( - stage_msg, - function_registry, - &self.runtime.as_ref(), - &combined_codec, - ) - .map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?; - let inner_plan = Arc::clone(&stage.plan); - - // Add the extensions that might be required for ExecutionPlan nodes in the plan - let config = state.config_mut(); - config.set_extension(Arc::clone(&self.channel_manager)); - config.set_extension(Arc::new(stage)); - - let ctx = SessionContext::new_with_state(state); - - let ctx = self - .session_builder - .session_context(ctx) - .await - .map_err(|err| datafusion_error_to_tonic_status(&err))?; + let partition = doget.partition as usize; + let task_number = doget.task_number as usize; + let (mut state, stage) = self.get_state_and_stage(doget).await?; + + // find out which partition group we are executing + let task = stage + .tasks + .get(task_number) + .ok_or(Status::invalid_argument(format!( + "Task number {} not found in stage {}", + task_number, + stage.name() + )))?; + + let partition_group = + PartitionGroup(task.partition_group.iter().map(|p| *p as usize).collect()); + state.config_mut().set_extension(Arc::new(partition_group)); + + let inner_plan = stage.plan.clone(); let stream = inner_plan - .execute(doget.partition as usize, ctx.task_ctx()) + .execute(partition, state.task_ctx()) .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?; let flight_data_stream = FlightDataEncoderBuilder::new() @@ -104,4 +83,71 @@ impl ArrowFlightEndpoint { }, )))) } + + async fn get_state_and_stage( + &self, + doget: DoGet, + ) -> Result<(SessionState, Arc), Status> { + let key = doget + .stage_key + .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?; + let once_stage = { + let entry = self.stages.entry(key).or_default(); + Arc::clone(&entry) + }; + + let (state, stage) = once_stage + .get_or_try_init(|| async { + let stage_proto = doget + .stage_proto + .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; + + let state_builder = SessionStateBuilder::new() + .with_runtime_env(Arc::clone(&self.runtime)) + 
.with_default_features(); + let state_builder = self + .session_builder + .session_state_builder(state_builder) + .map_err(|err| datafusion_error_to_tonic_status(&err))?; + + let state = state_builder.build(); + let mut state = self + .session_builder + .session_state(state) + .await + .map_err(|err| datafusion_error_to_tonic_status(&err))?; + + let function_registry = + state.function_registry().ok_or(Status::invalid_argument( + "FunctionRegistry not present in newly built SessionState", + ))?; + + let mut combined_codec = ComposedPhysicalExtensionCodec::default(); + combined_codec.push(DistributedCodec); + if let Some(ref user_codec) = get_user_codec(state.config()) { + combined_codec.push_arc(Arc::clone(user_codec)); + } + + let stage = stage_from_proto( + stage_proto, + function_registry, + self.runtime.as_ref(), + &combined_codec, + ) + .map(Arc::new) + .map_err(|err| { + Status::invalid_argument(format!("Cannot decode stage proto: {err}")) + })?; + + // Add the extensions that might be required for ExecutionPlan nodes in the plan + let config = state.config_mut(); + config.set_extension(Arc::clone(&self.channel_manager)); + config.set_extension(stage.clone()); + + Ok::<_, Status>((state, stage)) + }) + .await?; + + Ok((state.clone(), stage.clone())) + } } diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index 76373a9..4ae7b59 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -5,5 +5,5 @@ mod stream_partitioner_registry; pub(crate) use do_get::DoGet; -pub use service::ArrowFlightEndpoint; +pub use service::{ArrowFlightEndpoint, StageKey}; pub use session_builder::{NoopSessionBuilder, SessionBuilder}; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index b761f41..fc19269 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -1,7 +1,7 @@ use crate::channel_manager::ChannelManager; use crate::flight_service::session_builder::NoopSessionBuilder; -use crate::flight_service::stream_partitioner_registry::StreamPartitionerRegistry; use crate::flight_service::SessionBuilder; +use crate::stage::ExecutionStage; use crate::ChannelResolver; use arrow_flight::flight_service_server::FlightService; use arrow_flight::{ @@ -9,15 +9,33 @@ use arrow_flight::{ HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, }; use async_trait::async_trait; +use dashmap::DashMap; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::SessionState; use futures::stream::BoxStream; use std::sync::Arc; +use tokio::sync::OnceCell; use tonic::{Request, Response, Status, Streaming}; +/// A key that uniquely identifies a stage in a query +#[derive(Clone, Hash, Eq, PartialEq, ::prost::Message)] +pub struct StageKey { + /// Our query id + #[prost(string, tag = "1")] + pub query_id: String, + /// Our stage id + #[prost(uint64, tag = "2")] + pub stage_id: u64, + /// The task number within the stage + #[prost(uint64, tag = "3")] + pub task_number: u64, +} + pub struct ArrowFlightEndpoint { pub(super) channel_manager: Arc, pub(super) runtime: Arc, - pub(super) partitioner_registry: Arc, + #[allow(clippy::type_complexity)] + pub(super) stages: DashMap)>>>, pub(super) session_builder: Arc, } @@ -26,7 +44,7 @@ impl ArrowFlightEndpoint { Self { channel_manager: Arc::new(ChannelManager::new(channel_resolver)), runtime: Arc::new(RuntimeEnv::default()), - partitioner_registry: Arc::new(StreamPartitionerRegistry::default()), + stages: DashMap::new(), session_builder: 
Arc::new(NoopSessionBuilder), } } diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 2371ba5..93af9d1 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -12,10 +12,9 @@ use datafusion::{ config::ConfigOptions, error::Result, physical_optimizer::PhysicalOptimizerRule, - physical_plan::{ - displayable, repartition::RepartitionExec, ExecutionPlan, ExecutionPlanProperties, - }, + physical_plan::{repartition::RepartitionExec, ExecutionPlan, ExecutionPlanProperties}, }; +use uuid::Uuid; #[derive(Debug, Default)] pub struct DistributedPhysicalOptimizerRule { @@ -58,10 +57,6 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule { if plan.as_any().is::() { return Ok(plan); } - println!( - "DistributedPhysicalOptimizerRule: optimizing plan: {}", - displayable(plan.as_ref()).indent(false) - ); let plan = self.apply_network_boundaries(plan)?; let plan = self.distribute_plan(plan)?; @@ -112,11 +107,13 @@ impl DistributedPhysicalOptimizerRule { &self, plan: Arc, ) -> Result { - self._distribute_plan_inner(plan, &mut 1, 0) + let query_id = Uuid::new_v4(); + self._distribute_plan_inner(query_id, plan, &mut 1, 0) } fn _distribute_plan_inner( &self, + query_id: Uuid, plan: Arc, num: &mut usize, depth: usize, @@ -130,14 +127,14 @@ impl DistributedPhysicalOptimizerRule { let child = Arc::clone(node.children().first().cloned().ok_or( internal_datafusion_err!("Expected ArrowFlightExecRead to have a child"), )?); - let stage = self._distribute_plan_inner(child, num, depth + 1)?; + let stage = self._distribute_plan_inner(query_id, child, num, depth + 1)?; let node = Arc::new(node.to_distributed(stage.num)?); inputs.push(stage); Ok(Transformed::new(node, true, TreeNodeRecursion::Jump)) })?; let inputs = inputs.into_iter().map(Arc::new).collect(); - let mut stage = ExecutionStage::new(*num, distributed.data, inputs); + let mut stage = ExecutionStage::new(query_id, *num, distributed.data, inputs); *num += 1; if let Some(partitions_per_task) = self.partitions_per_task { diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index dd970f3..62c5e76 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -2,7 +2,7 @@ use super::combined::CombinedRecordBatchStream; use crate::channel_manager::ChannelManager; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; use crate::errors::tonic_status_to_datafusion_error; -use crate::flight_service::DoGet; +use crate::flight_service::{DoGet, StageKey}; use crate::plan::DistributedCodec; use crate::stage::{proto_from_stage, ExecutionStage}; use crate::user_provided_codec::get_user_codec; @@ -183,26 +183,44 @@ impl ExecutionPlan for ArrowFlightReadExec { internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}") })?; - let ticket_bytes = DoGet { - stage_proto: Some(child_stage_proto), - partition: partition as u64, - } - .encode_to_vec() - .into(); - - let ticket = Ticket { - ticket: ticket_bytes, - }; - let schema = child_stage.plan.schema(); let child_stage_tasks = child_stage.tasks.clone(); + let child_stage_num = child_stage.num as u64; + let query_id = stage.query_id.to_string(); + let stream = async move { - let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| async { - let url = task.url()?.ok_or(internal_datafusion_err!( - "ArrowFlightReadExec: task is unassigned, cannot proceed" - ))?; - stream_from_stage_task(ticket.clone(), &url, schema.clone(), &channel_manager).await + let futs = 
child_stage_tasks.iter().enumerate().map(|(i, task)| { + let child_stage_proto_capture = child_stage_proto.clone(); + let channel_manager_capture = channel_manager.clone(); + let schema = schema.clone(); + let query_id = query_id.clone(); + let key = StageKey { + query_id, + stage_id: child_stage_num, + task_number: i as u64, + }; + async move { + let url = task.url()?.ok_or(internal_datafusion_err!( + "ArrowFlightReadExec: task is unassigned, cannot proceed" + ))?; + + let ticket_bytes = DoGet { + stage_proto: Some(child_stage_proto_capture), + partition: partition as u64, + stage_key: Some(key), + task_number: i as u64, + } + .encode_to_vec() + .into(); + + let ticket = Ticket { + ticket: ticket_bytes, + }; + + stream_from_stage_task(ticket, &url, schema.clone(), &channel_manager_capture) + .await + } }); let streams = future::try_join_all(futs).await?; @@ -226,7 +244,7 @@ async fn stream_from_stage_task( schema: SchemaRef, channel_manager: &ChannelManager, ) -> Result { - let channel = channel_manager.get_channel_for_url(&url).await?; + let channel = channel_manager.get_channel_for_url(url).await?; let mut client = FlightServiceClient::new(channel); let stream = client diff --git a/src/plan/isolator.rs b/src/plan/isolator.rs index 9e50c56..0ebf82a 100644 --- a/src/plan/isolator.rs +++ b/src/plan/isolator.rs @@ -1,17 +1,18 @@ use std::{fmt::Formatter, sync::Arc}; use datafusion::{ + common::internal_datafusion_err, error::Result, execution::SendableRecordBatchStream, physical_plan::{ - DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, - ExecutionPlanProperties, Partitioning, PlanProperties, + DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, Partitioning, + PlanProperties, }, }; /// We will add this as an extension to the SessionConfig whenever we need /// to execute a plan that might include this node. -pub struct PartitionGroup(Vec); +pub struct PartitionGroup(pub Vec); /// This is a simple execution plan that isolates a partition from the input /// plan It will advertise that it has a single partition and when @@ -88,9 +89,19 @@ impl ExecutionPlan for PartitionIsolatorExec { context: std::sync::Arc, ) -> Result { let config = context.session_config(); - let partition_group = &[0, 1]; + let partition_group = config + .get_extension::() + .ok_or(internal_datafusion_err!( + "No extension PartitionGroup in SessionConfig" + ))? 
+ .0 + .clone(); - let partitions_in_input = self.input.output_partitioning().partition_count() as u64; + let partitions_in_input = self + .input + .properties() + .output_partitioning() + .partition_count(); // if our partition group is [7,8,9] and we are asked for parittion 1, // then look up that index in our group and execute that partition, in this @@ -103,8 +114,7 @@ impl ExecutionPlan for PartitionIsolatorExec { Ok(Box::pin(EmptyRecordBatchStream::new(self.input.schema())) as SendableRecordBatchStream) } else { - self.input - .execute(*actual_partition_number as usize, context) + self.input.execute(*actual_partition_number, context) } } None => Ok(Box::pin(EmptyRecordBatchStream::new(self.input.schema())) diff --git a/src/plan/mod.rs b/src/plan/mod.rs index 4ef089e..21b6ce5 100644 --- a/src/plan/mod.rs +++ b/src/plan/mod.rs @@ -5,4 +5,4 @@ mod isolator; pub use arrow_flight_read::ArrowFlightReadExec; pub use codec::DistributedCodec; -pub use isolator::PartitionIsolatorExec; +pub use isolator::{PartitionGroup, PartitionIsolatorExec}; diff --git a/src/stage/display.rs b/src/stage/display.rs index f3b7642..b4369d0 100644 --- a/src/stage/display.rs +++ b/src/stage/display.rs @@ -26,19 +26,13 @@ use super::ExecutionStage; // Unicode box-drawing characters for creating borders and connections. const LTCORNER: &str = "┌"; // Left top corner -const RTCORNER: &str = "┐"; // Right top corner const LDCORNER: &str = "└"; // Left bottom corner -const RDCORNER: &str = "┘"; // Right bottom corner - -const TMIDDLE: &str = "┬"; // Top T-junction (connects down) -const LMIDDLE: &str = "├"; // Left T-junction (connects right) -const DMIDDLE: &str = "┴"; // Bottom T-junction (connects up) - const VERTICAL: &str = "│"; // Vertical line const HORIZONTAL: &str = "─"; // Horizontal line impl DisplayAs for ExecutionStage { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + #[allow(clippy::format_in_format_args)] match t { DisplayFormatType::Default => { write!(f, "{}", self.name) diff --git a/src/stage/stage.rs b/src/stage/execution_stage.rs similarity index 93% rename from src/stage/stage.rs rename to src/stage/execution_stage.rs index b29b4ef..d66591c 100644 --- a/src/stage/stage.rs +++ b/src/stage/execution_stage.rs @@ -3,12 +3,13 @@ use std::sync::Arc; use datafusion::common::internal_err; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; -use datafusion::physical_plan::{displayable, ExecutionPlan}; +use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use itertools::Itertools; use rand::Rng; use url::Url; +use uuid::Uuid; use crate::task::ExecutionTask; use crate::ChannelManager; @@ -24,7 +25,7 @@ use crate::ChannelManager; /// When an [`ExecutionStage`] is execute()'d if will execute its plan and return a stream /// of record batches. /// -/// If the stage has input stages, then those input stages will be executed on remote resources +/// If the stage has input stages, then it those input stages will be executed on remote resources /// and will be provided the remainder of the stage tree. /// /// For example if our stage tree looks like this: @@ -74,6 +75,8 @@ use crate::ChannelManager; /// producing data from a [`DataSourceExec`]. 
#[derive(Debug, Clone)] pub struct ExecutionStage { + /// Our query_id + pub query_id: Uuid, /// Our stage number pub num: usize, /// Our stage name @@ -92,22 +95,18 @@ pub struct ExecutionStage { impl ExecutionStage { /// Creates a new `ExecutionStage` with the given plan and inputs. One task will be created /// responsible for partitions in the plan. - pub fn new(num: usize, plan: Arc, inputs: Vec>) -> Self { - println!( - "Creating ExecutionStage: {}, with inputs {}", - num, - inputs - .iter() - .map(|s| format!("{}", s.num)) - .collect::>() - .join(", ") - ); - + pub fn new( + query_id: Uuid, + num: usize, + plan: Arc, + inputs: Vec>, + ) -> Self { let name = format!("Stage {:<3}", num); let partition_group = (0..plan.properties().partitioning.partition_count()) .map(|p| p as u64) .collect(); ExecutionStage { + query_id, num, name, plan, @@ -209,9 +208,8 @@ impl ExecutionStage { }) .collect::>(); - println!("stage {} assigned_tasks: {:?}", self.num, assigned_tasks); - let assigned_stage = ExecutionStage { + query_id: self.query_id, num: self.num, name: self.name.clone(), plan: self.plan.clone(), @@ -242,6 +240,7 @@ impl ExecutionPlan for ExecutionStage { children: Vec>, ) -> Result> { Ok(Arc::new(ExecutionStage { + query_id: self.query_id, num: self.num, name: self.name.clone(), plan: self.plan.clone(), @@ -289,11 +288,6 @@ impl ExecutionPlan for ExecutionStage { let new_ctx = SessionContext::new_with_config_rt(config, context.runtime_env().clone()).task_ctx(); - println!( - "assinged_stage:\n{}", - displayable(assigned_stage.as_ref()).indent(true) - ); - assigned_stage.plan.execute(partition, new_ctx) } } diff --git a/src/stage/mod.rs b/src/stage/mod.rs index 034f1a3..3151f96 100644 --- a/src/stage/mod.rs +++ b/src/stage/mod.rs @@ -1,7 +1,7 @@ mod display; +mod execution_stage; mod proto; -mod stage; pub use display::display_stage_graphviz; +pub use execution_stage::ExecutionStage; pub use proto::{proto_from_stage, stage_from_proto, ExecutionStageProto}; -pub use stage::ExecutionStage; diff --git a/src/stage/proto.rs b/src/stage/proto.rs index d9a4d32..a4a9f97 100644 --- a/src/stage/proto.rs +++ b/src/stage/proto.rs @@ -17,21 +17,24 @@ use super::ExecutionStage; #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecutionStageProto { + /// Our query id + #[prost(bytes, tag = "1")] + pub query_id: Vec, /// Our stage number - #[prost(uint64, tag = "1")] + #[prost(uint64, tag = "2")] pub num: u64, /// Our stage name - #[prost(string, tag = "2")] + #[prost(string, tag = "3")] pub name: String, /// The physical execution plan that this stage will execute. 
- #[prost(message, optional, boxed, tag = "3")] + #[prost(message, optional, boxed, tag = "4")] pub plan: Option>, /// The input stages to this stage - #[prost(repeated, message, tag = "4")] - pub inputs: Vec>, + #[prost(repeated, message, tag = "5")] + pub inputs: Vec, /// Our tasks which tell us how finely grained to execute the partitions in /// the plan - #[prost(message, repeated, tag = "5")] + #[prost(message, repeated, tag = "6")] pub tasks: Vec, } @@ -42,10 +45,11 @@ pub fn proto_from_stage( let proto_plan = PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec)?; let inputs = stage .child_stages_iter() - .map(|s| Ok(Box::new(proto_from_stage(s, codec)?))) + .map(|s| proto_from_stage(s, codec)) .collect::>>()?; Ok(ExecutionStageProto { + query_id: stage.query_id.as_bytes().to_vec(), num: stage.num as u64, name: stage.name(), plan: Some(Box::new(proto_plan)), @@ -70,12 +74,16 @@ pub fn stage_from_proto( .inputs .into_iter() .map(|s| { - stage_from_proto(*s, registry, runtime, codec) + stage_from_proto(s, registry, runtime, codec) .map(|s| Arc::new(s) as Arc) }) .collect::>>()?; Ok(ExecutionStage { + query_id: msg + .query_id + .try_into() + .map_err(|_| internal_datafusion_err!("Invalid query_id in ExecutionStageProto"))?, num: msg.num as usize, name: msg.name, plan, @@ -102,6 +110,7 @@ mod tests { }; use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use prost::Message; + use uuid::Uuid; use crate::stage::proto::proto_from_stage; use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto}; @@ -141,6 +150,7 @@ mod tests { // Wrap it in an ExecutionStage let stage = ExecutionStage { + query_id: Uuid::new_v4(), num: 1, name: "TestStage".to_string(), plan: physical_plan, diff --git a/src/test_utils/insta.rs b/src/test_utils/insta.rs index 40fdb60..ad7b6c4 100644 --- a/src/test_utils/insta.rs +++ b/src/test_utils/insta.rs @@ -1,4 +1,3 @@ -use datafusion::common::utils::get_available_parallelism; use std::env; pub use insta; diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index fb5ff21..83a328e 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -8,31 +8,45 @@ use datafusion::execution::SessionStateBuilder; use datafusion::prelude::SessionContext; use datafusion::{common::runtime::JoinSet, prelude::SessionConfig}; use std::error::Error; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; use std::time::Duration; +use tokio::net::TcpListener; use tonic::transport::{Channel, Server}; use url::Url; -pub async fn start_localhost_context( - ports: I, +pub async fn start_localhost_context( + num_workers: usize, session_builder: B, ) -> (SessionContext, JoinSet<()>) where - N::Error: std::fmt::Debug, - N: TryInto, - I: IntoIterator, B: SessionBuilder + Send + Sync + 'static, B: Clone, { - let ports: Vec = ports.into_iter().map(|x| x.try_into().unwrap()).collect(); + let listeners = futures::future::try_join_all( + (0..num_workers) + .map(|_| TcpListener::bind("127.0.0.1:0")) + .collect::>(), + ) + .await + .expect("Failed to bind to address"); + + let ports: Vec = listeners + .iter() + .map(|listener| { + listener + .local_addr() + .expect("Failed to get local address") + .port() + }) + .collect(); + let channel_resolver = LocalHostChannelResolver::new(ports.clone()); let mut join_set = JoinSet::new(); - for port in ports { + for listener in listeners { let channel_resolver = channel_resolver.clone(); let session_builder = session_builder.clone(); join_set.spawn(async move { - 
spawn_flight_service(channel_resolver, session_builder, port) + spawn_flight_service(channel_resolver, session_builder, listener) .await .unwrap(); }); @@ -63,7 +77,6 @@ where #[derive(Clone)] pub struct LocalHostChannelResolver { ports: Vec, - i: Arc, } impl LocalHostChannelResolver { @@ -72,7 +85,6 @@ impl LocalHostChannelResolver { N::Error: std::fmt::Debug, { Self { - i: Arc::new(AtomicUsize::new(0)), ports: ports.into_iter().map(|v| v.try_into().unwrap()).collect(), } } @@ -97,13 +109,16 @@ impl ChannelResolver for LocalHostChannelResolver { pub async fn spawn_flight_service( channel_resolver: impl ChannelResolver + Send + Sync + 'static, session_builder: impl SessionBuilder + Send + Sync + 'static, - port: u16, + incoming: TcpListener, ) -> Result<(), Box> { let mut endpoint = ArrowFlightEndpoint::new(channel_resolver); endpoint.with_session_builder(session_builder); + + let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming); + Ok(Server::builder() .add_service(FlightServiceServer::new(endpoint)) - .serve(format!("127.0.0.1:{port}").parse()?) + .serve_with_incoming(incoming) .await?) } diff --git a/src/test_utils/tpch.rs b/src/test_utils/tpch.rs index e435eab..b10787c 100644 --- a/src/test_utils/tpch.rs +++ b/src/test_utils/tpch.rs @@ -148,7 +148,7 @@ where writer.write(&first_batch)?; - while let Some(batch) = data_source.next() { + for batch in data_source { writer.write(&batch)?; } diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index c0673f8..3d2870d 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -49,8 +49,7 @@ mod tests { } } - let (ctx, _guard) = - start_localhost_context([50050, 50051, 50052], CustomSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; let single_node_plan = build_plan(false)?; assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r" diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs index 0d70ae4..fa59aa5 100644 --- a/tests/distributed_aggregation.rs +++ b/tests/distributed_aggregation.rs @@ -10,12 +10,8 @@ mod tests { use std::error::Error; #[tokio::test] - #[ignore] async fn distributed_aggregation() -> Result<(), Box> { - // FIXME: these ports are in use on my machine, we should find unused ports - // Changed them for now - let (ctx, _guard) = - start_localhost_context([40050, 40051, 40052], NoopSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, NoopSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx @@ -31,16 +27,18 @@ mod tests { .indent(true) .to_string(); + println!("physical plan:\n{}", physical_str); + assert_snapshot!(physical_str, @r" ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] - SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] + SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([RainToday@0], CPUs), input_partitions=CPUs - RepartitionExec: partitioning=RoundRobinBatch(CPUs), input_partitions=1 + RepartitionExec: partitioning=Hash([RainToday@0], 3), input_partitions=3 + RepartitionExec: 
partitioning=RoundRobinBatch(3), input_partitions=1 AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet ", @@ -48,18 +46,27 @@ mod tests { assert_snapshot!(physical_distributed_str, @r" - ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] - SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] - SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] - ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] - AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] - ArrowFlightReadExec: input_tasks=8 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/] - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([RainToday@0], CPUs), input_partitions=CPUs - RepartitionExec: partitioning=RoundRobinBatch(CPUs), input_partitions=1 - AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] - ArrowFlightReadExec: input_tasks=1 hash_expr=[RainToday@0] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50052/] - DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet + ┌───── Stage 3 Task: partitions: 0,unassigned] + │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] + │partitions [out:1 <-- in:8 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] + │partitions [out:8 <-- in:8 ] SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true] + │partitions [out:8 <-- in:8 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] + │partitions [out:8 <-- in:8 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │partitions [out:8 ] ArrowFlightReadExec: Stage 2 + │ + └────────────────────────────────────────────────── + ┌───── Stage 2 Task: partitions: 0..2,unassigned] + │partitions [out:3 <-- in:3 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:3 <-- in:3 ] RepartitionExec: partitioning=Hash([RainToday@0], 3), input_partitions=3 + │partitions [out:3 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1 + │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │partitions [out:1 ] ArrowFlightReadExec: Stage 1 + │ + └────────────────────────────────────────────────── + ┌───── Stage 1 Task: partitions: 0,unassigned] + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet + │ + └────────────────────────────────────────────────── ", ); diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index 6afb196..e8d2ca5 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -36,8 +36,7 @@ mod tests { Ok(with_user_codec(builder, ErrorExecCodec)) } } - let (ctx, _guard) = - start_localhost_context([50050, 50051, 50053], CustomSessionBuilder).await; + let (ctx, _guard) = start_localhost_context(3, 
CustomSessionBuilder).await; let mut plan: Arc = Arc::new(ErrorExec::new("something failed")); diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index 3994ee1..d6c2cfa 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -12,13 +12,7 @@ mod tests { #[tokio::test] #[ignore] async fn highly_distributed_query() -> Result<(), Box> { - let (ctx, _guard) = start_localhost_context( - [ - 50050, 50051, 50053, 50054, 50055, 50056, 50057, 50058, 50059, - ], - NoopSessionBuilder, - ) - .await; + let (ctx, _guard) = start_localhost_context(9, NoopSessionBuilder).await; register_parquet_tables(&ctx).await?; let df = ctx.sql(r#"SELECT * FROM flights_1m"#).await?; diff --git a/tests/non_distributed_consistency_test.rs b/tests/non_distributed_consistency_test.rs index debcde0..ea2f395 100644 --- a/tests/non_distributed_consistency_test.rs +++ b/tests/non_distributed_consistency_test.rs @@ -3,11 +3,15 @@ mod common; #[cfg(all(feature = "integration", test))] mod tests { use crate::common::{ensure_tpch_data, get_test_data_dir, get_test_tpch_query}; + use async_trait::async_trait; + use datafusion::error::DataFusionError; use datafusion::execution::SessionStateBuilder; - use datafusion::physical_plan::execute_stream; use datafusion::prelude::{SessionConfig, SessionContext}; + use datafusion_distributed::test_utils::localhost::start_localhost_context; + use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; use futures::TryStreamExt; use std::error::Error; + use std::sync::Arc; #[tokio::test] async fn test_tpch_1() -> Result<(), Box> { @@ -80,9 +84,6 @@ mod tests { } #[tokio::test] - #[ignore] - // TODO: Support query 15? - // Skip because it contains DDL statements not supported in single SQL execution async fn test_tpch_15() -> Result<(), Box> { test_tpch_query(15).await } @@ -118,19 +119,66 @@ mod tests { } #[tokio::test] + // TODO: Add support for NestedLoopJoinExec to support query 22. + #[ignore] async fn test_tpch_22() -> Result<(), Box> { test_tpch_query(22).await } + async fn test_tpch_query(query_id: u8) -> Result<(), Box> { + let (ctx, _guard) = start_localhost_context(2, TestSessionBuilder).await; + run_tpch_query(ctx, query_id).await + } + + #[derive(Clone)] + struct TestSessionBuilder; + + #[async_trait] + impl SessionBuilder for TestSessionBuilder { + fn session_state_builder( + &self, + builder: SessionStateBuilder, + ) -> Result { + let mut config = SessionConfig::new().with_target_partitions(3); + + // FIXME: these three options are critical for the correct function of the library + // but we are not enforcing that the user sets them. They are here at the moment + // but we should figure out a way to do this better. + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold = 0; + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold_rows = 0; + + config.options_mut().optimizer.prefer_hash_join = true; + // end critical options section + + let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(2); + Ok(builder + .with_config(config) + .with_physical_optimizer_rule(Arc::new(rule))) + } + + async fn session_context( + &self, + ctx: SessionContext, + ) -> std::result::Result { + Ok(ctx) + } + } + // test_non_distributed_consistency runs each TPC-H query twice - once in a distributed manner // and once in a non-distributed manner. For each query, it asserts that the results are identical. 
- async fn test_tpch_query(query_id: u8) -> Result<(), Box> { + async fn run_tpch_query(ctx2: SessionContext, query_id: u8) -> Result<(), Box> { ensure_tpch_data().await; let sql = get_test_tpch_query(query_id); // Context 1: Non-distributed execution. - let config1 = SessionConfig::new(); + let config1 = SessionConfig::new().with_target_partitions(3); let state1 = SessionStateBuilder::new() .with_default_features() .with_config(config1) @@ -148,31 +196,7 @@ mod tests { datafusion::prelude::ParquetReadOptions::default(), ) .await?; - } - - let df1 = ctx1.sql(&sql).await?; - let physical1 = df1.create_physical_plan().await?; - - let batches1 = execute_stream(physical1.clone(), ctx1.task_ctx())? - .try_collect::>() - .await?; - - // Context 2: Distributed execution. - // TODO: once distributed execution is working, we can enable distributed features here. - let config2 = SessionConfig::new(); - // .with_target_partitions(3); - let state2 = SessionStateBuilder::new() - .with_default_features() - .with_config(config2) - // .with_optimizer_rule(DistributedPhysicalOptimizerRule::default().with_maximum_partitions_per_task(4)) - .build(); - let ctx2 = SessionContext::new_with_state(state2); - // Register tables for second context - for table_name in [ - "lineitem", "orders", "part", "partsupp", "customer", "nation", "region", "supplier", - ] { - let query_path = get_test_data_dir().join(format!("{}.parquet", table_name)); ctx2.register_parquet( table_name, query_path.to_string_lossy().as_ref(), @@ -181,12 +205,33 @@ mod tests { .await?; } - let df2 = ctx2.sql(&sql).await?; - let physical2 = df2.create_physical_plan().await?; - - let batches2 = execute_stream(physical2.clone(), ctx2.task_ctx())? - .try_collect::>() - .await?; + let (stream1, stream2) = if query_id == 15 { + let queries: Vec<&str> = sql + .split(';') + .map(str::trim) + .filter(|s| !s.is_empty()) + .collect(); + + println!("queryies: {:?}", queries); + + ctx1.sql(queries[0]).await?.collect().await?; + ctx2.sql(queries[0]).await?.collect().await?; + let df1 = ctx1.sql(queries[1]).await?; + let df2 = ctx2.sql(queries[1]).await?; + let stream1 = df1.execute_stream().await?; + let stream2 = df2.execute_stream().await?; + + ctx1.sql(queries[2]).await?.collect().await?; + ctx2.sql(queries[2]).await?.collect().await?; + (stream1, stream2) + } else { + let stream1 = ctx1.sql(&sql).await?.execute_stream().await?; + let stream2 = ctx2.sql(&sql).await?.execute_stream().await?; + (stream1, stream2) + }; + + let batches1 = stream1.try_collect::>().await?; + let batches2 = stream2.try_collect::>().await?; let formatted1 = arrow::util::pretty::pretty_format_batches(&batches1)?; let formatted2 = arrow::util::pretty::pretty_format_batches(&batches2)?; diff --git a/tests/stage_planning.rs b/tests/stage_planning.rs deleted file mode 100644 index ffacf77..0000000 --- a/tests/stage_planning.rs +++ /dev/null @@ -1,84 +0,0 @@ -mod common; - -#[cfg(all(feature = "integration", test))] -mod tests { - use crate::common::get_test_queries_dir; - use datafusion::arrow::util::pretty::pretty_format_batches; - use datafusion::execution::SessionStateBuilder; - use datafusion::physical_plan::{displayable, execute_stream}; - use datafusion::prelude::{SessionConfig, SessionContext}; - use datafusion_distributed::assert_snapshot; - use datafusion_distributed::test_utils::tpch::tpch_query_from_dir; - use datafusion_distributed::DistributedPhysicalOptimizerRule; - use datafusion_distributed::{display_stage_graphviz, ExecutionStage}; - use futures::TryStreamExt; - use 
std::error::Error; - use std::sync::Arc; - - // FIXME: ignored out until we figure out how to integrate best with tpch - #[tokio::test] - #[ignore] - async fn stage_planning() -> Result<(), Box> { - let config = SessionConfig::new().with_target_partitions(3); - - let rule = DistributedPhysicalOptimizerRule::default().with_maximum_partitions_per_task(4); - - let state = SessionStateBuilder::new() - .with_default_features() - .with_config(config) - .with_physical_optimizer_rule(Arc::new(rule)) - .build(); - - let ctx = SessionContext::new_with_state(state); - - for table_name in [ - "lineitem", "orders", "part", "partsupp", "customer", "nation", "region", "supplier", - ] { - let query_path = format!("testdata/tpch/{}.parquet", table_name); - ctx.register_parquet( - table_name, - query_path, - datafusion::prelude::ParquetReadOptions::default(), - ) - .await?; - } - - let queries_dir = get_test_queries_dir(); - let sql = tpch_query_from_dir(&queries_dir, 2); - //let sql = "select 1;"; - println!("SQL Query:\n{}", sql); - - let df = ctx.sql(&sql).await?; - - let physical = df.create_physical_plan().await?; - - let physical_str = displayable(physical.as_ref()).tree_render(); - println!("\n\nPhysical Plan:\n{}", physical_str); - - let physical_str = displayable(physical.as_ref()).indent(false); - println!("\n\nPhysical Plan:\n{}", physical_str); - - let physical_str = displayable(physical.as_ref()).indent(true); - println!("\n\nPhysical Plan:\n{}", physical_str); - - let physical_str = - display_stage_graphviz(physical.as_any().downcast_ref::().unwrap())?; - println!("\n\nPhysical Plan:\n{}", physical_str); - - assert_snapshot!(physical_str, - @r" - ", - ); - - let batches = pretty_format_batches( - &execute_stream(physical.clone(), ctx.task_ctx())? - .try_collect::>() - .await?, - )?; - - assert_snapshot!(batches, @r" - "); - - Ok(()) - } -}
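
Usage sketch (not part of the patch): the integration tests above move from hard-coded port lists to a worker count passed to `start_localhost_context`, so spinning up a distributed context for a test now looks roughly like the following. This is a hedged sketch only; the import paths for `NoopSessionBuilder` and the parquet table-registration helper are assumptions based on how the updated tests in this diff use them.

```rust
// Minimal sketch mirroring the updated integration tests in this diff.
// Assumed (not confirmed by the diff): exact import paths for
// NoopSessionBuilder and register_parquet_tables within this crate.
use datafusion_distributed::test_utils::localhost::start_localhost_context;

#[tokio::test]
async fn three_worker_smoke() -> Result<(), Box<dyn std::error::Error>> {
    // Binds 3 Arrow Flight workers to OS-assigned localhost ports
    // (no more hard-coded port lists) and returns a SessionContext wired
    // to them, plus a guard that keeps the workers alive for the test.
    let (ctx, _guard) = start_localhost_context(3, NoopSessionBuilder).await;

    // Register whatever tables the query needs, then run SQL through the
    // distributed context exactly as the updated tests do.
    register_parquet_tables(&ctx).await?;
    let batches = ctx
        .sql("SELECT * FROM flights_1m LIMIT 10")
        .await?
        .collect()
        .await?;
    assert!(!batches.is_empty());
    Ok(())
}
```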