From 2359329dacb3fe1926930f679e4934b1dd0f7ae3 Mon Sep 17 00:00:00 2001
From: Gabriel Musat Mestre
Date: Wed, 3 Sep 2025 16:58:13 +0200
Subject: [PATCH] ArrowFlightReadExec should always have a RepartitionExec as
 a child + uncomment working tests

---
 src/physical_optimizer.rs         |  7 +---
 src/plan/arrow_flight_read.rs     |  8 ++--
 src/test_utils/mod.rs             |  1 -
 src/test_utils/plan.rs            | 65 ------------------------------
 tests/custom_config_extension.rs  |  8 ++--
 tests/custom_extension_codec.rs   | 33 ++++++++--------
 tests/distributed_aggregation.rs  | 38 ++++++++++--------
 tests/error_propagation.rs        |  8 ++--
 tests/highly_distributed_query.rs | 40 ++++++++++++++-----
 9 files changed, 83 insertions(+), 125 deletions(-)
 delete mode 100644 src/test_utils/plan.rs

diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs
index 2b3489b..eddc646 100644
--- a/src/physical_optimizer.rs
+++ b/src/physical_optimizer.rs
@@ -13,7 +13,7 @@ use datafusion::{
     config::ConfigOptions,
     error::Result,
     physical_optimizer::PhysicalOptimizerRule,
-    physical_plan::{repartition::RepartitionExec, ExecutionPlan, ExecutionPlanProperties},
+    physical_plan::{repartition::RepartitionExec, ExecutionPlan},
 };
 use uuid::Uuid;
 
@@ -94,10 +94,7 @@ impl DistributedPhysicalOptimizerRule {
         };
 
         return Ok(Transformed::yes(Arc::new(
-            ArrowFlightReadExec::new_pending(
-                Arc::clone(&maybe_isolated_plan),
-                maybe_isolated_plan.output_partitioning().clone(),
-            ),
+            ArrowFlightReadExec::new_pending(Arc::clone(&maybe_isolated_plan)),
         )));
     }
 
diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs
index 44d5f9b..d3b9ac5 100644
--- a/src/plan/arrow_flight_read.rs
+++ b/src/plan/arrow_flight_read.rs
@@ -16,7 +16,9 @@ use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
+};
 use futures::{StreamExt, TryFutureExt, TryStreamExt};
 use http::Extensions;
 use prost::Message;
@@ -57,11 +59,11 @@ pub struct ArrowFlightReadReadyExec {
 }
 
 impl ArrowFlightReadExec {
-    pub fn new_pending(child: Arc<dyn ExecutionPlan>, partitioning: Partitioning) -> Self {
+    pub fn new_pending(child: Arc<dyn ExecutionPlan>) -> Self {
         Self::Pending(ArrowFlightReadPendingExec {
             properties: PlanProperties::new(
                 EquivalenceProperties::new(child.schema()),
-                partitioning,
+                child.output_partitioning().clone(),
                 EmissionType::Incremental,
                 Boundedness::Bounded,
             ),
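This is the heart of the patch: `ArrowFlightReadExec::new_pending` no longer takes a `Partitioning` argument and instead inherits the output partitioning of its child, so a caller that wants a specific partitioning has to hang an explicit `RepartitionExec` under the read node. A minimal sketch of the migration at a call site (the helper name is illustrative, not part of the crate):

    use std::sync::Arc;

    use datafusion::error::DataFusionError;
    use datafusion::physical_expr::Partitioning;
    use datafusion::physical_plan::repartition::RepartitionExec;
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion_distributed::ArrowFlightReadExec;

    // Before: ArrowFlightReadExec::new_pending(plan, Partitioning::RoundRobinBatch(8))
    // After: the partitioning lives in an explicit RepartitionExec child, and the
    // read node simply mirrors whatever partitioning that child exposes.
    fn read_over_repartition(
        plan: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        Ok(Arc::new(ArrowFlightReadExec::new_pending(Arc::new(
            RepartitionExec::try_new(plan, partitioning)?,
        ))))
    }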
diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs
index 9b8e408..2ef71e3 100644
--- a/src/test_utils/mod.rs
+++ b/src/test_utils/mod.rs
@@ -2,5 +2,4 @@ pub mod insta;
 pub mod localhost;
 pub mod mock_exec;
 pub mod parquet;
-pub mod plan;
 pub mod tpch;
diff --git a/src/test_utils/plan.rs b/src/test_utils/plan.rs
deleted file mode 100644
index b545451..0000000
--- a/src/test_utils/plan.rs
+++ /dev/null
@@ -1,65 +0,0 @@
-use crate::{ArrowFlightReadExec, DistributedPhysicalOptimizerRule};
-use datafusion::common::plan_err;
-use datafusion::common::tree_node::{Transformed, TreeNode};
-use datafusion::error::DataFusionError;
-use datafusion::physical_expr::Partitioning;
-use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
-use datafusion::physical_plan::ExecutionPlan;
-use std::sync::Arc;
-
-pub fn distribute_aggregate(
-    plan: Arc<dyn ExecutionPlan>,
-) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
-    let mut aggregate_partial_found = false;
-    let transformed = plan.transform_up(|node| {
-        let Some(agg) = node.as_any().downcast_ref::<AggregateExec>() else {
-            return Ok(Transformed::no(node));
-        };
-
-        match agg.mode() {
-            AggregateMode::Partial => {
-                if aggregate_partial_found {
-                    return plan_err!("Two consecutive partial aggregations found");
-                }
-                aggregate_partial_found = true;
-                let expr = agg
-                    .group_expr()
-                    .expr()
-                    .iter()
-                    .map(|(v, _)| Arc::clone(v))
-                    .collect::<Vec<_>>();
-
-                if node.children().len() != 1 {
-                    return plan_err!("Aggregate must have exactly one child");
-                }
-                let child = node.children()[0].clone();
-
-                let node = node.with_new_children(vec![Arc::new(
-                    ArrowFlightReadExec::new_pending(child, Partitioning::Hash(expr, 1)),
-                )])?;
-                Ok(Transformed::yes(node))
-            }
-            AggregateMode::Final
-            | AggregateMode::FinalPartitioned
-            | AggregateMode::Single
-            | AggregateMode::SinglePartitioned => {
-                if !aggregate_partial_found {
-                    return plan_err!("No partial aggregate found before the final one");
-                }
-
-                if node.children().len() != 1 {
-                    return plan_err!("Aggregate must have exactly one child");
-                }
-                let child = node.children()[0].clone();
-
-                let node = node.with_new_children(vec![Arc::new(
-                    ArrowFlightReadExec::new_pending(child, Partitioning::RoundRobinBatch(8)),
-                )])?;
-                Ok(Transformed::yes(node))
-            }
-        }
-    })?;
-    Ok(Arc::new(
-        DistributedPhysicalOptimizerRule::default().distribute_plan(transformed.data)?,
-    ))
-}
diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs
index a08b615..5560dfe 100644
--- a/tests/custom_config_extension.rs
+++ b/tests/custom_config_extension.rs
@@ -9,6 +9,7 @@ mod tests {
     };
     use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
     use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+    use datafusion::physical_plan::repartition::RepartitionExec;
     use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
     use datafusion::physical_plan::{
         execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
@@ -46,10 +47,9 @@ mod tests {
         let mut plan: Arc<dyn ExecutionPlan> = Arc::new(CustomConfigExtensionRequiredExec::new());
 
         for size in [1, 2, 3] {
-            plan = Arc::new(ArrowFlightReadExec::new_pending(
-                plan,
-                Partitioning::RoundRobinBatch(size),
-            ));
+            plan = Arc::new(ArrowFlightReadExec::new_pending(Arc::new(
+                RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(size))?,
+            )));
         }
 
         let plan = DistributedPhysicalOptimizerRule::default().distribute_plan(plan)?;
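Every test below follows the same migration: the `Partitioning` that used to be declared on `ArrowFlightReadExec` is now expressed as a `RepartitionExec` wrapped around the child. For the loop above, the resulting tree nests three repartition/read pairs (RoundRobinBatch with sizes 1, 2 and 3) on top of `CustomConfigExtensionRequiredExec`, and `distribute_plan` then appears to cut the tree into one stage per `ArrowFlightReadExec` boundary, which is the staged structure visible in the snapshots further down.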
diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs
index fff1c55..4b9fb43 100644
--- a/tests/custom_extension_codec.rs
+++ b/tests/custom_extension_codec.rs
@@ -36,7 +36,6 @@ mod tests {
     use std::sync::Arc;
 
     #[tokio::test]
-    #[ignore]
    async fn custom_extension_codec() -> Result<(), Box<dyn Error>> {
        async fn build_state(
            ctx: DistributedSessionBuilderContext,
@@ -66,17 +65,16 @@
        │partitions [out:1 <-- in:1 ] SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
        │partitions [out:1 <-- in:10 ] RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=10
        │partitions [out:10 ] ArrowFlightReadExec: Stage 2
-        │
        └──────────────────────────────────────────────────
-        ┌───── Stage 2 Task: partitions: 0,unassigned]
-        │partitions [out:1 <-- in:1 ] SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
-        │partitions [out:1 ] ArrowFlightReadExec: Stage 1
-        │
+        ┌───── Stage 2 Task: partitions: 0..9,unassigned]
+        │partitions [out:10 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
+        │partitions [out:1 <-- in:1 ] SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
+        │partitions [out:1 ] ArrowFlightReadExec: Stage 1
        └──────────────────────────────────────────────────
        ┌───── Stage 1 Task: partitions: 0,unassigned]
-        │partitions [out:1 <-- in:1 ] FilterExec: numbers@0 > 1
-        │partitions [out:1 ] Int64ListExec: length=6
-        │
+        │partitions [out:1 <-- in:1 ] RepartitionExec: partitioning=Hash([numbers@0], 1), input_partitions=1
+        │partitions [out:1 <-- in:1 ] FilterExec: numbers@0 > 1
+        │partitions [out:1 ] Int64ListExec: length=6
        └──────────────────────────────────────────────────
        ");
 
@@ -125,10 +123,12 @@
        )?);
 
        if distributed {
-            plan = Arc::new(ArrowFlightReadExec::new_pending(
-                plan.clone(),
-                Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1),
-            ));
+            plan = Arc::new(ArrowFlightReadExec::new_pending(Arc::new(
+                RepartitionExec::try_new(
+                    plan.clone(),
+                    Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1),
+                )?,
+            )));
        }
 
        plan = Arc::new(SortExec::new(
@@ -141,10 +141,9 @@
        ));
 
        if distributed {
-            plan = Arc::new(ArrowFlightReadExec::new_pending(
-                plan.clone(),
-                Partitioning::RoundRobinBatch(10),
-            ));
+            plan = Arc::new(ArrowFlightReadExec::new_pending(Arc::new(
+                RepartitionExec::try_new(plan.clone(), Partitioning::RoundRobinBatch(10))?,
+            )));
 
            plan = Arc::new(RepartitionExec::try_new(
                plan,
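A note on reading these snapshots: each `┌───── Stage N` box is the plan fragment one worker task executes, and `[out:a <-- in:b]` means the operator exposes `a` output partitions while consuming `b` partitions from its child, so `RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1` shows up as `[out:10 <-- in:1]`. The visible effect of this patch is that Stage 2 now carries the repartition that fans the single sorted partition out to ten, and Stage 1 carries the hash repartition, instead of the read nodes declaring those partitionings themselves.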
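The aggregation test now drives distribution through the optimizer rule itself rather than the deleted `distribute_aggregate` helper, and caps each task at a single partition. That cap is what produces the three `Task: partitions: N,unassigned]` entries per stage and the `PartitionIsolatorExec [providing upto 1 partitions]` nodes in the snapshot above. A sketch of that entry point, assuming default session config (the wrapper function is illustrative; `optimize` is the standard `PhysicalOptimizerRule` method):

    use std::sync::Arc;

    use datafusion::error::DataFusionError;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion_distributed::DistributedPhysicalOptimizerRule;

    fn distribute_one_partition_per_task(
        physical: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        DistributedPhysicalOptimizerRule::default()
            // With at most 1 partition per task, a stage that exposes 3
            // partitions is executed as 3 tasks, each isolated to its slice.
            .with_maximum_partitions_per_task(1)
            .optimize(physical, &Default::default())
    }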
diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs
index 273b332..4794a5f 100644
--- a/tests/distributed_aggregation.rs
+++ b/tests/distributed_aggregation.rs
@@ -1,11 +1,13 @@
 #[cfg(all(feature = "integration", test))]
 mod tests {
     use datafusion::arrow::util::pretty::pretty_format_batches;
+    use datafusion::physical_optimizer::PhysicalOptimizerRule;
     use datafusion::physical_plan::{displayable, execute_stream};
     use datafusion_distributed::test_utils::localhost::start_localhost_context;
     use datafusion_distributed::test_utils::parquet::register_parquet_tables;
-    use datafusion_distributed::test_utils::plan::distribute_aggregate;
-    use datafusion_distributed::{assert_snapshot, DefaultSessionBuilder};
+    use datafusion_distributed::{
+        assert_snapshot, DefaultSessionBuilder, DistributedPhysicalOptimizerRule,
+    };
     use futures::TryStreamExt;
     use std::error::Error;
 
@@ -21,7 +23,9 @@
 
        let physical_str = displayable(physical.as_ref()).indent(true).to_string();
 
-        let physical_distributed = distribute_aggregate(physical.clone())?;
+        let physical_distributed = DistributedPhysicalOptimizerRule::default()
+            .with_maximum_partitions_per_task(1)
+            .optimize(physical.clone(), &Default::default())?;
        let physical_distributed_str = displayable(physical_distributed.as_ref())
            .indent(true)
            .to_string();
@@ -48,21 +52,23 @@
        @r"
        ┌───── Stage 3 Task: partitions: 0,unassigned]
        │partitions [out:1 <-- in:1 ] ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
-        │partitions [out:1 <-- in:8 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
-        │partitions [out:8 <-- in:8 ] SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true]
-        │partitions [out:8 <-- in:8 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
-        │partitions [out:8 <-- in:8 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
-        │partitions [out:8 ] ArrowFlightReadExec: Stage 2
+        │partitions [out:1 <-- in:3 ] SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
+        │partitions [out:3 <-- in:3 ] SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true]
+        │partitions [out:3 <-- in:3 ] ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
+        │partitions [out:3 <-- in:3 ] AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:3 <-- in:3 ] CoalesceBatchesExec: target_batch_size=8192
+        │partitions [out:3 ] ArrowFlightReadExec: Stage 2
        └──────────────────────────────────────────────────
-        ┌───── Stage 2 Task: partitions: 0..2,unassigned]
-        │partitions [out:3 <-- in:3 ] CoalesceBatchesExec: target_batch_size=8192
-        │partitions [out:3 <-- in:3 ] RepartitionExec: partitioning=Hash([RainToday@0], 3), input_partitions=3
-        │partitions [out:3 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1
-        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
-        │partitions [out:1 ] ArrowFlightReadExec: Stage 1
+        ┌───── Stage 2 Task: partitions: 0,unassigned],Task: partitions: 1,unassigned],Task: partitions: 2,unassigned]
+        │partitions [out:3 <-- in:1 ] RepartitionExec: partitioning=Hash([RainToday@0], 3), input_partitions=1
+        │partitions [out:1 <-- in:3 ] PartitionIsolatorExec [providing upto 1 partitions]
+        │partitions [out:3 ] ArrowFlightReadExec: Stage 1
        └──────────────────────────────────────────────────
-        ┌───── Stage 1 Task: partitions: 0,unassigned]
-        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
+        ┌───── Stage 1 Task: partitions: 0,unassigned],Task: partitions: 1,unassigned],Task: partitions: 2,unassigned]
+        │partitions [out:3 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1
+        │partitions [out:1 <-- in:1 ] PartitionIsolatorExec [providing upto 1 partitions]
+        │partitions [out:1 <-- in:1 ] AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
        └──────────────────────────────────────────────────
        ",
        );
diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs
index e4424fc..f1da05b 100644
--- a/tests/error_propagation.rs
+++ b/tests/error_propagation.rs
@@ -7,6 +7,7 @@ mod tests {
     };
     use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
     use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+    use datafusion::physical_plan::repartition::RepartitionExec;
     use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
     use datafusion::physical_plan::{
         execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
@@ -42,10 +43,9 @@ mod tests {
         let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));
 
         for size in [1, 2, 3] {
-            plan = Arc::new(ArrowFlightReadExec::new_pending(
-                plan,
-                Partitioning::RoundRobinBatch(size),
-            ));
+            plan = Arc::new(ArrowFlightReadExec::new_pending(Arc::new(
+                RepartitionExec::try_new(plan, Partitioning::RoundRobinBatch(size))?,
+            )));
         }
         let plan = DistributedPhysicalOptimizerRule::default().distribute_plan(plan)?;
         let stream = execute_stream(Arc::new(plan), ctx.task_ctx())?;
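The loop above still nests three network boundaries, so the `something failed` error raised by `ErrorExec` in the innermost stage has to cross three Arrow Flight hops before `execute_stream` surfaces it to the test. The added `RepartitionExec` wrappers change the plan shape but should leave that propagation path untouched, which is exactly what this test guards.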
diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs
index c2b8160..9c13365 100644
--- a/tests/highly_distributed_query.rs
+++ b/tests/highly_distributed_query.rs
@@ -1,16 +1,19 @@
 #[cfg(all(feature = "integration", test))]
 mod tests {
     use datafusion::physical_expr::Partitioning;
+    use datafusion::physical_plan::repartition::RepartitionExec;
     use datafusion::physical_plan::{displayable, execute_stream};
     use datafusion_distributed::test_utils::localhost::start_localhost_context;
     use datafusion_distributed::test_utils::parquet::register_parquet_tables;
-    use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, DefaultSessionBuilder};
+    use datafusion_distributed::{
+        assert_snapshot, ArrowFlightReadExec, DefaultSessionBuilder,
+        DistributedPhysicalOptimizerRule,
+    };
     use futures::TryStreamExt;
     use std::error::Error;
     use std::sync::Arc;
 
     #[tokio::test]
-    #[ignore]
     async fn highly_distributed_query() -> Result<(), Box<dyn Error>> {
         let (ctx, _guard) = start_localhost_context(9, DefaultSessionBuilder).await;
         register_parquet_tables(&ctx).await?;
@@ -21,11 +24,17 @@
 
        let mut physical_distributed = physical.clone();
        for size in [1, 10, 5] {
-            physical_distributed = Arc::new(ArrowFlightReadExec::new_pending(
-                physical_distributed,
-                Partitioning::RoundRobinBatch(size),
-            ));
+            physical_distributed = Arc::new(ArrowFlightReadExec::new_pending(Arc::new(
+                RepartitionExec::try_new(
+                    physical_distributed,
+                    Partitioning::RoundRobinBatch(size),
+                )?,
+            )));
        }
+
+        let physical_distributed =
+            DistributedPhysicalOptimizerRule::default().distribute_plan(physical_distributed)?;
+        let physical_distributed = Arc::new(physical_distributed);
        let physical_distributed_str = displayable(physical_distributed.as_ref())
            .indent(true)
            .to_string();
@@ -36,10 +45,21 @@
 
        assert_snapshot!(physical_distributed_str,
            @r"
-        ArrowFlightReadExec: input_tasks=5 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50053/, http://localhost:50054/, http://localhost:50055/]
-          ArrowFlightReadExec: input_tasks=10 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50056/, http://localhost:50057/, http://localhost:50058/, http://localhost:50059/, http://localhost:50050/, http://localhost:50051/, http://localhost:50053/, http://localhost:50054/, http://localhost:50055/, http://localhost:50056/]
-            ArrowFlightReadExec: input_tasks=1 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50057/]
-              DataSourceExec: file_groups={1 group: [[/testdata/flights-1m.parquet]]}, projection=[FL_DATE, DEP_DELAY, ARR_DELAY, AIR_TIME, DISTANCE, DEP_TIME, ARR_TIME], file_type=parquet
+        ┌───── Stage 4 Task: partitions: 0..4,unassigned]
+        │partitions [out:5 ] ArrowFlightReadExec: Stage 3
+        └──────────────────────────────────────────────────
+        ┌───── Stage 3 Task: partitions: 0..4,unassigned]
+        │partitions [out:5 <-- in:10 ] RepartitionExec: partitioning=RoundRobinBatch(5), input_partitions=10
+        │partitions [out:10 ] ArrowFlightReadExec: Stage 2
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0..9,unassigned]
+        │partitions [out:10 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
+        │partitions [out:1 ] ArrowFlightReadExec: Stage 1
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1
+        │partitions [out:1 ] DataSourceExec: file_groups={1 group: [[/testdata/flights-1m.parquet]]}, projection=[FL_DATE, DEP_DELAY, ARR_DELAY, AIR_TIME, DISTANCE, DEP_TIME, ARR_TIME], file_type=parquet
+        └──────────────────────────────────────────────────
        ",
        );
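Reading the re-enabled snapshot bottom-up matches the `[1, 10, 5]` sizes in the loop: Stage 1 wraps the single parquet partition in a RoundRobinBatch(1) repartition, Stage 2 fans it out to ten partitions, Stage 3 folds those into five, and Stage 4 reads those five partitions as one task covering partitions 0..4. Note also that the hand-built plan is now passed through `distribute_plan` before being displayed and executed; that call is what resolves the pending read nodes into the staged layout shown here, replacing the earlier flat display with per-host `input_tasks` annotations.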