From 72dc9708f1ac0081982ae60fc4b1f3e80cb7fb2a Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 5 Aug 2025 12:08:19 +0200 Subject: [PATCH 1/7] Uncomment tests in favor of just #[ignore]-ing them --- tests/common/mod.rs | 1 + tests/common/plan.rs | 68 +++++++++++++++++++++++++++++++ tests/custom_extension_codec.rs | 14 ++++--- tests/distributed_aggregation.rs | 43 +++++++++++++++++-- tests/error_propagation.rs | 12 +++--- tests/highly_distributed_query.rs | 15 +++---- tests/stage_planning.rs | 6 +-- 7 files changed, 134 insertions(+), 25 deletions(-) create mode 100644 tests/common/plan.rs diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 4a55075..a100491 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,4 @@ pub mod insta; pub mod localhost; pub mod parquet; +pub mod plan; diff --git a/tests/common/plan.rs b/tests/common/plan.rs new file mode 100644 index 0000000..892bba5 --- /dev/null +++ b/tests/common/plan.rs @@ -0,0 +1,68 @@ +use datafusion::common::plan_err; +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::error::DataFusionError; +use datafusion::physical_expr::Partitioning; +use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_distributed::ArrowFlightReadExec; +use std::sync::Arc; + +pub fn distribute_aggregate( + plan: Arc, +) -> Result, DataFusionError> { + let mut aggregate_partial_found = false; + Ok(plan + .transform_up(|node| { + let Some(agg) = node.as_any().downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + + match agg.mode() { + AggregateMode::Partial => { + if aggregate_partial_found { + return plan_err!("Two consecutive partial aggregations found"); + } + aggregate_partial_found = true; + let expr = agg + .group_expr() + .expr() + .iter() + .map(|(v, _)| Arc::clone(v)) + .collect::>(); + + if node.children().len() != 1 { + return plan_err!("Aggregate must have exactly one child"); + } + let child = node.children()[0].clone(); + + let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new( + Partitioning::Hash(expr, 1), + child.schema(), + 0, + ))])?; + Ok(Transformed::yes(node)) + } + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned => { + if !aggregate_partial_found { + return plan_err!("No partial aggregate found before the final one"); + } + + if node.children().len() != 1 { + return plan_err!("Aggregate must have exactly one child"); + } + let child = node.children()[0].clone(); + + let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new( + Partitioning::RoundRobinBatch(8), + child.schema(), + 1, + ))])?; + Ok(Transformed::yes(node)) + } + } + })? + .data) +} diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs index 7a64a71..7367616 100644 --- a/tests/custom_extension_codec.rs +++ b/tests/custom_extension_codec.rs @@ -1,6 +1,6 @@ #[allow(dead_code)] mod common; -/* + #[cfg(test)] mod tests { use crate::assert_snapshot; @@ -27,7 +27,7 @@ mod tests { use datafusion::physical_plan::{ displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; - use datafusion_distributed::{assign_stages, ArrowFlightReadExec, SessionBuilder}; + use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder}; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use futures::{stream, TryStreamExt}; @@ -37,6 +37,7 @@ mod tests { use std::sync::Arc; #[tokio::test] + #[ignore] async fn custom_extension_codec() -> Result<(), Box> { #[derive(Clone)] struct CustomSessionBuilder; @@ -66,7 +67,6 @@ mod tests { "); let distributed_plan = build_plan(true)?; - let distributed_plan = assign_stages(distributed_plan, &ctx)?; assert_snapshot!(displayable(distributed_plan.as_ref()).indent(true).to_string(), @r" SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] @@ -124,8 +124,9 @@ mod tests { if distributed { plan = Arc::new(ArrowFlightReadExec::new( - plan.clone(), Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1), + plan.clone().schema(), + 0, // TODO: stage num should be assigned by someone else )); } @@ -139,8 +140,9 @@ mod tests { if distributed { plan = Arc::new(ArrowFlightReadExec::new( - plan.clone(), Partitioning::RoundRobinBatch(10), + plan.clone().schema(), + 1, // TODO: stage num should be assigned by someone else )); plan = Arc::new(RepartitionExec::try_new( @@ -266,4 +268,4 @@ mod tests { .map_err(|err| proto_error(format!("{err}"))) } } -}*/ +} diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs index 68e6def..3776606 100644 --- a/tests/distributed_aggregation.rs +++ b/tests/distributed_aggregation.rs @@ -6,12 +6,14 @@ mod tests { use crate::assert_snapshot; use crate::common::localhost::{start_localhost_context, NoopSessionBuilder}; use crate::common::parquet::register_parquet_tables; + use crate::common::plan::distribute_aggregate; use datafusion::arrow::util::pretty::pretty_format_batches; use datafusion::physical_plan::{displayable, execute_stream}; use futures::TryStreamExt; use std::error::Error; #[tokio::test] + #[ignore] async fn distributed_aggregation() -> Result<(), Box> { // FIXME these ports are in use on my machine, we should find unused ports // Changed them for now @@ -26,9 +28,13 @@ mod tests { let physical_str = displayable(physical.as_ref()).indent(true).to_string(); - println!("\n\nPhysical Plan:\n{}", physical_str); + let physical_distributed = distribute_aggregate(physical.clone())?; - /*assert_snapshot!(physical_str, + let physical_distributed_str = displayable(physical_distributed.as_ref()) + .indent(true) + .to_string(); + + assert_snapshot!(physical_str, @r" ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] @@ -41,7 +47,24 @@ mod tests { AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet ", - );*/ + ); + + assert_snapshot!(physical_distributed_str, + @r" + ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] + SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] + SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] + ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] + AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + ArrowFlightReadExec: input_tasks=8 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/] + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([RainToday@0], CPUs), input_partitions=CPUs + RepartitionExec: partitioning=RoundRobinBatch(CPUs), input_partitions=1 + AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + ArrowFlightReadExec: input_tasks=1 hash_expr=[RainToday@0] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50052/] + DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet + ", + ); let batches = pretty_format_batches( &execute_stream(physical, ctx.task_ctx())? @@ -58,6 +81,20 @@ mod tests { +----------+-----------+ "); + let batches_distributed = pretty_format_batches( + &execute_stream(physical_distributed, ctx.task_ctx())? + .try_collect::>() + .await?, + )?; + assert_snapshot!(batches_distributed, @r" + +----------+-----------+ + | count(*) | RainToday | + +----------+-----------+ + | 66 | Yes | + | 300 | No | + +----------+-----------+ + "); + Ok(()) } } diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs index 1fce8ad..05e3e95 100644 --- a/tests/error_propagation.rs +++ b/tests/error_propagation.rs @@ -1,6 +1,6 @@ #[allow(dead_code)] mod common; -/* + #[cfg(test)] mod tests { use crate::common::localhost::start_localhost_context; @@ -26,6 +26,7 @@ mod tests { use std::sync::Arc; #[tokio::test] + #[ignore] async fn test_error_propagation() -> Result<(), Box> { #[derive(Clone)] struct CustomSessionBuilder; @@ -48,14 +49,13 @@ mod tests { let mut plan: Arc = Arc::new(ErrorExec::new("something failed")); - for size in [1, 2, 3] { + for (i, size) in [1, 2, 3].iter().enumerate() { plan = Arc::new(ArrowFlightReadExec::new( - Partitioning::RoundRobinBatch(size), + Partitioning::RoundRobinBatch(*size as usize), plan.schema(), - 0, + i, )); } - let plan = assign_stages(plan, &ctx)?; let stream = execute_stream(plan, ctx.task_ctx())?; let Err(err) = stream.try_collect::>().await else { @@ -170,4 +170,4 @@ mod tests { .map_err(|err| proto_error(format!("{err}"))) } } -}*/ +} diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs index 3523694..385c430 100644 --- a/tests/highly_distributed_query.rs +++ b/tests/highly_distributed_query.rs @@ -1,6 +1,6 @@ #[allow(dead_code)] mod common; -/* + #[cfg(test)] mod tests { use crate::assert_snapshot; @@ -8,12 +8,13 @@ mod tests { use crate::common::parquet::register_parquet_tables; use datafusion::physical_expr::Partitioning; use datafusion::physical_plan::{displayable, execute_stream}; - use datafusion_distributed::{assign_stages, ArrowFlightReadExec}; + use datafusion_distributed::ArrowFlightReadExec; use futures::TryStreamExt; use std::error::Error; use std::sync::Arc; #[tokio::test] + #[ignore] async fn highly_distributed_query() -> Result<(), Box> { let (ctx, _guard) = start_localhost_context( [ @@ -29,13 +30,13 @@ mod tests { let physical_str = displayable(physical.as_ref()).indent(true).to_string(); let mut physical_distributed = physical.clone(); - for size in [1, 10, 5] { + for (i, size) in [1, 10, 5].iter().enumerate() { physical_distributed = Arc::new(ArrowFlightReadExec::new( - physical_distributed.clone(), - Partitioning::RoundRobinBatch(size), + Partitioning::RoundRobinBatch(*size as usize), + physical_distributed.schema(), + i, )); } - let physical_distributed = assign_stages(physical_distributed, &ctx)?; let physical_distributed_str = displayable(physical_distributed.as_ref()) .indent(true) .to_string(); @@ -75,4 +76,4 @@ mod tests { Ok(()) } -}*/ +} diff --git a/tests/stage_planning.rs b/tests/stage_planning.rs index 1ed6558..c567459 100644 --- a/tests/stage_planning.rs +++ b/tests/stage_planning.rs @@ -1,8 +1,6 @@ mod common; mod tpch; -// FIXME: commented out until we figure out how to integrate best with tpch -/* #[cfg(test)] mod tests { use crate::tpch::tpch_query; @@ -17,7 +15,9 @@ mod tests { use std::error::Error; use std::sync::Arc; + // FIXME: ignored out until we figure out how to integrate best with tpch #[tokio::test] + #[ignore] async fn stage_planning() -> Result<(), Box> { let config = SessionConfig::new().with_target_partitions(3); @@ -86,4 +86,4 @@ mod tests { Ok(()) } -}*/ +} From 15e4edf4aac92ed7fa0811dee772f00f6f3bd329 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 5 Aug 2025 12:16:50 +0200 Subject: [PATCH 2/7] Remove unused import statements --- src/common/mod.rs | 1 - src/common/result.rs | 1 - src/flight_service/do_get.rs | 1 - src/physical_optimizer.rs | 1 - src/plan/arrow_flight_read.rs | 2 +- src/stage/display.rs | 2 +- src/stage/proto.rs | 1 - src/task.rs | 3 +-- 8 files changed, 3 insertions(+), 9 deletions(-) delete mode 100644 src/common/result.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 5668f16..812d1ed 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,2 +1 @@ -pub mod result; pub mod util; diff --git a/src/common/result.rs b/src/common/result.rs deleted file mode 100644 index 8b13789..0000000 --- a/src/common/result.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index afbf3c5..ed9646b 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -8,7 +8,6 @@ use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; use datafusion::execution::SessionStateBuilder; use datafusion::optimizer::OptimizerConfig; -use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::TryStreamExt; use prost::Message; diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index f5ebbe1..3f7a057 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -6,7 +6,6 @@ use datafusion::{ tree_node::{Transformed, TreeNode, TreeNodeRewriter}, }, config::ConfigOptions, - datasource::physical_plan::FileSource, error::Result, physical_optimizer::PhysicalOptimizerRule, physical_plan::{ diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index 36be92f..b0c37d1 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -5,7 +5,7 @@ use arrow_flight::{FlightClient, Ticket}; use datafusion::arrow::datatypes::SchemaRef; use datafusion::common::{internal_datafusion_err, plan_err}; use datafusion::error::Result; -use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; diff --git a/src/stage/display.rs b/src/stage/display.rs index 51e980d..6d5e3d6 100644 --- a/src/stage/display.rs +++ b/src/stage/display.rs @@ -14,7 +14,7 @@ use std::fmt::Write; use datafusion::{ error::Result, - physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}, + physical_plan::{DisplayAs, DisplayFormatType}, }; use crate::{ diff --git a/src/stage/proto.rs b/src/stage/proto.rs index 1d61dc1..f50f1fa 100644 --- a/src/stage/proto.rs +++ b/src/stage/proto.rs @@ -10,7 +10,6 @@ use datafusion_proto::{ physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}, protobuf::PhysicalPlanNode, }; -use prost::Message; use crate::{plan::DistributedCodec, task::ExecutionTask}; diff --git a/src/task.rs b/src/task.rs index 7913848..b912834 100644 --- a/src/task.rs +++ b/src/task.rs @@ -3,9 +3,8 @@ use std::fmt::Display; use std::fmt::Formatter; use datafusion::common::internal_datafusion_err; -use prost::Message; - use datafusion::error::Result; + use url::Url; #[derive(Clone, PartialEq, ::prost::Message)] From e4e2dbb222bce5143cfa767ac36277838531e602 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 5 Aug 2025 12:17:23 +0200 Subject: [PATCH 3/7] Move common tpch module to common --- tests/common/mod.rs | 1 + tests/{tpch/mod.rs => common/tpch.rs} | 0 tests/stage_planning.rs | 20 +++++++------------- 3 files changed, 8 insertions(+), 13 deletions(-) rename tests/{tpch/mod.rs => common/tpch.rs} (100%) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index a100491..6a11300 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -2,3 +2,4 @@ pub mod insta; pub mod localhost; pub mod parquet; pub mod plan; +pub mod tpch; diff --git a/tests/tpch/mod.rs b/tests/common/tpch.rs similarity index 100% rename from tests/tpch/mod.rs rename to tests/common/tpch.rs diff --git a/tests/stage_planning.rs b/tests/stage_planning.rs index c567459..9632d5b 100644 --- a/tests/stage_planning.rs +++ b/tests/stage_planning.rs @@ -1,12 +1,12 @@ +#[allow(dead_code)] mod common; -mod tpch; #[cfg(test)] mod tests { - use crate::tpch::tpch_query; - use crate::{assert_snapshot, tpch}; + use crate::assert_snapshot; + use crate::common::tpch::tpch_query; use datafusion::arrow::util::pretty::pretty_format_batches; - use datafusion::execution::{SessionState, SessionStateBuilder}; + use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::{displayable, execute_stream}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule; @@ -69,20 +69,14 @@ mod tests { ", ); - /*let batches = pretty_format_batches( - &execute_stream(physical, ctx.task_ctx())? + let batches = pretty_format_batches( + &execute_stream(physical.clone(), ctx.task_ctx())? .try_collect::>() .await?, )?; assert_snapshot!(batches, @r" - +----------+-----------+ - | count(*) | RainToday | - +----------+-----------+ - | 66 | Yes | - | 300 | No | - +----------+-----------+ - ");*/ + "); Ok(()) } From 3f9100328dc066f21c91641db3f1d5b23565c7d9 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 5 Aug 2025 14:53:48 +0200 Subject: [PATCH 4/7] Unignore one more test --- src/stage/proto.rs | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/stage/proto.rs b/src/stage/proto.rs index f50f1fa..7d0d0d8 100644 --- a/src/stage/proto.rs +++ b/src/stage/proto.rs @@ -99,9 +99,7 @@ pub fn stage_from_proto( } // add tests for round trip to and from a proto message for ExecutionStage -/* TODO: broken for now #[cfg(test)] - mod tests { use std::sync::Arc; @@ -110,19 +108,13 @@ mod tests { array::{RecordBatch, StringArray, UInt8Array}, datatypes::{DataType, Field, Schema}, }, - catalog::memory::DataSourceExec, - common::{internal_datafusion_err, internal_err}, + common::internal_datafusion_err, datasource::MemTable, - error::{DataFusionError, Result}, + error::Result, execution::context::SessionContext, - prelude::SessionConfig, - }; - use datafusion_proto::{ - physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}, - protobuf::PhysicalPlanNode, }; + use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use prost::Message; - use uuid::Uuid; use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto}; @@ -147,6 +139,7 @@ mod tests { } #[tokio::test] + #[ignore] async fn test_execution_stage_proto_round_trip() -> Result<()> { let ctx = SessionContext::new(); let mem_table = create_mem_table(); @@ -195,4 +188,4 @@ mod tests { assert_eq!(stage.name, round_trip_stage.name); Ok(()) } -}*/ +} From 17be2c5ba1bbb5ee91834fc8fd74e3f93bae56cf Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 5 Aug 2025 17:19:26 +0200 Subject: [PATCH 5/7] Fix ArrowFlightReadExec --- src/plan/arrow_flight_read.rs | 58 +++++++++++++++++------------------ 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index b0c37d1..ae15fc8 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,10 +1,15 @@ +use super::combined::CombinedRecordBatchStream; use crate::channel_manager::ChannelManager; +use crate::errors::tonic_status_to_datafusion_error; use crate::flight_service::DoGet; use crate::stage::{ExecutionStage, ExecutionStageProto}; -use arrow_flight::{FlightClient, Ticket}; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::error::FlightError; +use arrow_flight::flight_service_client::FlightServiceClient; +use arrow_flight::Ticket; use datafusion::arrow::datatypes::SchemaRef; use datafusion::common::{internal_datafusion_err, plan_err}; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -15,11 +20,8 @@ use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; -use tonic::transport::Channel; use url::Url; -use super::combined::CombinedRecordBatchStream; - #[derive(Debug, Clone)] pub struct ArrowFlightReadExec { /// the number of the stage we are reading from @@ -125,8 +127,6 @@ impl ExecutionPlan for ArrowFlightReadExec { let schema = child_stage.plan.schema(); let stream = async move { - // concurrenly build streams for each stage - // TODO: tokio spawn instead here? let futs = child_stage_tasks.iter().map(|task| async { let url = task.url()?.ok_or(internal_datafusion_err!( "ArrowFlightReadExec: task is unassigned, cannot proceed" @@ -153,31 +153,29 @@ async fn stream_from_stage_task( ticket: Ticket, url: &Url, schema: SchemaRef, - _channel_manager: &ChannelManager, -) -> Result { - // FIXME: I cannot figure how how to use the arrow_flight::client::FlightClient (a mid level - // client) with the ChannelManager, so we willc create a new Channel directly for now + channel_manager: &ChannelManager, +) -> Result { + let channel = channel_manager.get_channel_for_url(&url).await?; - //let channel = channel_manager.get_channel_for_url(&url).await?; - - let channel = Channel::from_shared(url.to_string()) - .map_err(|e| internal_datafusion_err!("Failed to create channel from URL: {e:#?}"))? - .connect() - .await - .map_err(|e| internal_datafusion_err!("Failed to connect to channel: {e:#?}"))?; - - let mut client = FlightClient::new(channel); - - let flight_stream = client + let mut client = FlightServiceClient::new(channel); + let stream = client .do_get(ticket) .await - .map_err(|e| internal_datafusion_err!("Failed to execute do_get for ticket: {e:#?}"))?; - - let record_batch_stream = RecordBatchStreamAdapter::new( + .map_err(|err| { + tonic_status_to_datafusion_error(&err) + .unwrap_or_else(|| DataFusionError::External(Box::new(err))) + })? + .into_inner() + .map_err(|err| FlightError::Tonic(Box::new(err))); + + let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err { + FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status) + .unwrap_or_else(|| DataFusionError::External(Box::new(status))), + err => DataFusionError::External(Box::new(err)), + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( schema.clone(), - flight_stream - .map_err(|e| internal_datafusion_err!("Failed to decode flight stream: {e:#?}")), - ); - - Ok(Box::pin(record_batch_stream) as SendableRecordBatchStream) + stream, + ))) } From 8067b446271d3fdcc2b3694b0bb27a24a0a25f70 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Wed, 6 Aug 2025 07:33:16 +0200 Subject: [PATCH 6/7] Add stage planner tests --- src/physical_optimizer.rs | 204 ++++++++++++++++++++++++++++++++++++++ src/test_utils/insta.rs | 24 +++++ src/test_utils/mod.rs | 5 +- src/test_utils/parquet.rs | 20 ++++ 4 files changed, 251 insertions(+), 2 deletions(-) create mode 100644 src/test_utils/insta.rs create mode 100644 src/test_utils/parquet.rs diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 3f7a057..9a54486 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -258,3 +258,207 @@ impl TreeNodeRewriter for StagePlanner { } } } + +#[cfg(test)] +mod tests { + use crate::assert_snapshot; + use crate::physical_optimizer::DistributedPhysicalOptimizerRule; + use crate::test_utils::register_parquet_tables; + use datafusion::error::DataFusionError; + use datafusion::execution::SessionStateBuilder; + use datafusion::physical_plan::displayable; + use datafusion::prelude::{SessionConfig, SessionContext}; + use std::sync::Arc; + + /* shema for the "weather" table + + MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL] + MaxTemp [type=DOUBLE] [repetitiontype=OPTIONAL] + Rainfall [type=DOUBLE] [repetitiontype=OPTIONAL] + Evaporation [type=DOUBLE] [repetitiontype=OPTIONAL] + Sunshine [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindGustDir [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindGustSpeed [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindDir9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindDir3pm [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindSpeed9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindSpeed3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Humidity9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Humidity3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Pressure9am [type=DOUBLE] [repetitiontype=OPTIONAL] + Pressure3pm [type=DOUBLE] [repetitiontype=OPTIONAL] + Cloud9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Cloud3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Temp9am [type=DOUBLE] [repetitiontype=OPTIONAL] + Temp3pm [type=DOUBLE] [repetitiontype=OPTIONAL] + RainToday [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + RISK_MM [type=DOUBLE] [repetitiontype=OPTIONAL] + RainTomorrow [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + */ + + #[tokio::test] + async fn test_select_all() { + let query = r#"SELECT * FROM weather"#; + let plan = sql_to_explain(query).await.unwrap(); + assert_snapshot!(plan, @r" + ┌───── Stage 1 Task: partitions: 0,unassigned] + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet + │ + └────────────────────────────────────────────────── + "); + } + + #[tokio::test] + async fn test_aggregation() { + let query = + r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#; + let plan = sql_to_explain(query).await.unwrap(); + assert_snapshot!(plan, @r" + ┌───── Stage 3 Task: partitions: 0,unassigned] + │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] + │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] + │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] + │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] + │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 ] ArrowFlightReadExec: Stage 2 + │ + └────────────────────────────────────────────────── + ┌───── Stage 2 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4 + │partitions [out:4 ] ArrowFlightReadExec: Stage 1 + │ + └────────────────────────────────────────────────── + ┌───── Stage 1 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 + │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet + │ + └────────────────────────────────────────────────── + "); + } + + #[tokio::test] + async fn test_aggregation_with_partitions_per_task() { + let query = + r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#; + let plan = sql_to_explain_partitions_per_task(query, 2).await.unwrap(); + assert_snapshot!(plan, @r" + ┌───── Stage 3 Task: partitions: 0,unassigned] + │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] + │partitions [out:1 <-- in:4 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] + │partitions [out:4 <-- in:4 ] SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] + │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] + │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 ] ArrowFlightReadExec: Stage 2 + │ + └────────────────────────────────────────────────── + ┌───── Stage 2 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned] + │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=2 + │partitions [out:2 <-- in:4 ] PartitionIsolatorExec [providing upto 2 partitions] + │partitions [out:4 ] ArrowFlightReadExec: Stage 1 + │ + └────────────────────────────────────────────────── + ┌───── Stage 1 Task: partitions: 0,1,unassigned],Task: partitions: 2,3,unassigned] + │partitions [out:4 <-- in:2 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 + │partitions [out:2 <-- in:1 ] PartitionIsolatorExec [providing upto 2 partitions] + │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet + │ + └────────────────────────────────────────────────── + "); + } + + #[tokio::test] + async fn test_left_join() { + let query = r#"SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday" "#; + let plan = sql_to_explain(query).await.unwrap(); + assert_snapshot!(plan, @r" + ┌───── Stage 1 Task: partitions: 0,unassigned] + │partitions [out:1 <-- in:1 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:1 <-- in:1 ] HashJoinExec: mode=Partitioned, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2] + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet + │ + └────────────────────────────────────────────────── + "); + } + + #[tokio::test] + async fn test_sort() { + let query = r#"SELECT * FROM weather ORDER BY "MinTemp" DESC "#; + let plan = sql_to_explain(query).await.unwrap(); + assert_snapshot!(plan, @r" + ┌───── Stage 1 Task: partitions: 0,unassigned] + │partitions [out:1 <-- in:1 ] SortExec: expr=[MinTemp@0 DESC], preserve_partitioning=[false] + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet + │ + └────────────────────────────────────────────────── + "); + } + + #[tokio::test] + async fn test_distinct() { + let query = r#"SELECT DISTINCT "RainToday", "WindGustDir" FROM weather"#; + let plan = sql_to_explain(query).await.unwrap(); + assert_snapshot!(plan, @r" + ┌───── Stage 3 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 ] ArrowFlightReadExec: Stage 2 + │ + └────────────────────────────────────────────────── + ┌───── Stage 2 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainToday@0, WindGustDir@1], 4), input_partitions=4 + │partitions [out:4 ] ArrowFlightReadExec: Stage 1 + │ + └────────────────────────────────────────────────── + ┌───── Stage 1 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 + │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday, WindGustDir@1 as WindGustDir], aggr=[] + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday, WindGustDir], file_type=parquet + │ + └────────────────────────────────────────────────── + "); + } + + async fn sql_to_explain(query: &str) -> Result { + sql_to_explain_with_rule(query, DistributedPhysicalOptimizerRule::new()).await + } + + async fn sql_to_explain_partitions_per_task( + query: &str, + partitions_per_task: usize, + ) -> Result { + sql_to_explain_with_rule( + query, + DistributedPhysicalOptimizerRule::new() + .with_maximum_partitions_per_task(partitions_per_task), + ) + .await + } + + async fn sql_to_explain_with_rule( + query: &str, + rule: DistributedPhysicalOptimizerRule, + ) -> Result { + let config = SessionConfig::new().with_target_partitions(4); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_physical_optimizer_rule(Arc::new(rule)) + .with_config(config) + .build(); + + let ctx = SessionContext::new_with_state(state); + register_parquet_tables(&ctx).await?; + + let df = ctx.sql(query).await?; + + let physical_plan = df.create_physical_plan().await?; + let display = displayable(physical_plan.as_ref()).indent(true).to_string(); + + Ok(display) + } +} diff --git a/src/test_utils/insta.rs b/src/test_utils/insta.rs new file mode 100644 index 0000000..00e9c1d --- /dev/null +++ b/src/test_utils/insta.rs @@ -0,0 +1,24 @@ +use std::env; + +#[macro_export] +macro_rules! assert_snapshot { + ($($arg:tt)*) => { + crate::test_utils::insta::settings().bind(|| { + insta::assert_snapshot!($($arg)*); + }) + }; +} + +pub fn settings() -> insta::Settings { + env::set_var("INSTA_WORKSPACE_ROOT", env!("CARGO_MANIFEST_DIR")); + let mut settings = insta::Settings::clone_current(); + let cwd = env::current_dir().unwrap(); + let cwd = cwd.to_str().unwrap(); + settings.add_filter(cwd.trim_start_matches("/"), ""); + settings.add_filter( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + "UUID", + ); + + settings +} diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs index 83d345e..c012f4b 100644 --- a/src/test_utils/mod.rs +++ b/src/test_utils/mod.rs @@ -1,5 +1,6 @@ -#[cfg(test)] +pub mod insta; mod mock_exec; +mod parquet; -#[cfg(test)] pub use mock_exec::MockExec; +pub use parquet::register_parquet_tables; diff --git a/src/test_utils/parquet.rs b/src/test_utils/parquet.rs new file mode 100644 index 0000000..6efadd4 --- /dev/null +++ b/src/test_utils/parquet.rs @@ -0,0 +1,20 @@ +use datafusion::error::DataFusionError; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; + +pub async fn register_parquet_tables(ctx: &SessionContext) -> Result<(), DataFusionError> { + ctx.register_parquet( + "flights_1m", + "testdata/flights-1m.parquet", + ParquetReadOptions::default(), + ) + .await?; + + ctx.register_parquet( + "weather", + "testdata/weather.parquet", + ParquetReadOptions::default(), + ) + .await?; + + Ok(()) +} From 2acc65f524ddbbd32409565bb7e744a2fa6e3375 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Wed, 6 Aug 2025 12:37:32 +0200 Subject: [PATCH 7/7] Add test_left_join_distributed --- src/physical_optimizer.rs | 73 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs index 9a54486..56f1a2f 100644 --- a/src/physical_optimizer.rs +++ b/src/physical_optimizer.rs @@ -385,6 +385,79 @@ mod tests { "); } + #[tokio::test] + async fn test_left_join_distributed() { + let query = r#" + WITH a AS ( + SELECT + AVG("MinTemp") as "MinTemp", + "RainTomorrow" + FROM weather + WHERE "RainToday" = 'yes' + GROUP BY "RainTomorrow" + ), b AS ( + SELECT + AVG("MaxTemp") as "MaxTemp", + "RainTomorrow" + FROM weather + WHERE "RainToday" = 'no' + GROUP BY "RainTomorrow" + ) + SELECT + a."MinTemp", + b."MaxTemp" + FROM a + LEFT JOIN b + ON a."RainTomorrow" = b."RainTomorrow" + + "#; + let plan = sql_to_explain(query).await.unwrap(); + assert_snapshot!(plan, @r" + ┌───── Stage 5 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 <-- in:1 ] HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainTomorrow@1, RainTomorrow@1)], projection=[MinTemp@0, MaxTemp@2] + │partitions [out:1 <-- in:4 ] CoalescePartitionsExec + │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[avg(weather.MinTemp)@1 as MinTemp, RainTomorrow@0 as RainTomorrow] + │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MinTemp)] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 ] ArrowFlightReadExec: Stage 2 + │partitions [out:4 <-- in:4 ] ProjectionExec: expr=[avg(weather.MaxTemp)@1 as MaxTemp, RainTomorrow@0 as RainTomorrow] + │partitions [out:4 <-- in:4 ] AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[avg(weather.MaxTemp)] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 ] ArrowFlightReadExec: Stage 4 + │ + └────────────────────────────────────────────────── + ┌───── Stage 4 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4 + │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MaxTemp)] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = no, projection=[MaxTemp@0, RainTomorrow@2] + │partitions [out:4 ] ArrowFlightReadExec: Stage 3 + │ + └────────────────────────────────────────────────── + ┌───── Stage 3 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MaxTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = no, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= no AND no <= RainToday_max@1, required_guarantees=[RainToday in (no)] + │ + │ + └────────────────────────────────────────────────── + ┌───── Stage 2 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:4 ] RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=4 + │partitions [out:4 <-- in:4 ] AggregateExec: mode=Partial, gby=[RainTomorrow@1 as RainTomorrow], aggr=[avg(weather.MinTemp)] + │partitions [out:4 <-- in:4 ] CoalesceBatchesExec: target_batch_size=8192 + │partitions [out:4 <-- in:4 ] FilterExec: RainToday@1 = yes, projection=[MinTemp@0, RainTomorrow@2] + │partitions [out:4 ] ArrowFlightReadExec: Stage 1 + │ + └────────────────────────────────────────────────── + ┌───── Stage 1 Task: partitions: 0..3,unassigned] + │partitions [out:4 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 + │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[MinTemp, RainToday, RainTomorrow], file_type=parquet, predicate=RainToday@1 = yes, pruning_predicate=RainToday_null_count@2 != row_count@3 AND RainToday_min@0 <= yes AND yes <= RainToday_max@1, required_guarantees=[RainToday in (yes)] + │ + │ + └────────────────────────────────────────────────── + "); + } + #[tokio::test] async fn test_sort() { let query = r#"SELECT * FROM weather ORDER BY "MinTemp" DESC "#;