11use std:: sync:: Arc ;
22
33use super :: stage:: ExecutionStage ;
4+ use crate :: common:: util:: can_be_divided;
45use crate :: { plan:: PartitionIsolatorExec , ArrowFlightReadExec } ;
56use datafusion:: common:: tree_node:: TreeNodeRecursion ;
67use datafusion:: error:: DataFusionError ;
@@ -83,12 +84,14 @@ impl DistributedPhysicalOptimizerRule {
8384 internal_datafusion_err ! ( "Expected RepartitionExec to have a child" ) ,
8485 ) ?) ;
8586
86- let maybe_isolated_plan = if let Some ( ppt) = self . partitions_per_task {
87- let isolated = Arc :: new ( PartitionIsolatorExec :: new ( child, ppt) ) ;
88- plan. with_new_children ( vec ! [ isolated] ) ?
89- } else {
90- plan
91- } ;
87+ let maybe_isolated_plan =
88+ if can_be_divided ( & plan) ? && self . partitions_per_task . is_some ( ) {
89+ let ppt = self . partitions_per_task . unwrap ( ) ;
90+ let isolated = Arc :: new ( PartitionIsolatorExec :: new ( child, ppt) ) ;
91+ plan. with_new_children ( vec ! [ isolated] ) ?
92+ } else {
93+ plan
94+ } ;
9295
9396 return Ok ( Transformed :: yes ( Arc :: new (
9497 ArrowFlightReadExec :: new_pending (
@@ -120,7 +123,7 @@ impl DistributedPhysicalOptimizerRule {
120123 ) -> Result < ExecutionStage , DataFusionError > {
121124 let mut inputs = vec ! [ ] ;
122125
123- let distributed = plan. transform_down ( |plan| {
126+ let distributed = plan. clone ( ) . transform_down ( |plan| {
124127 let Some ( node) = plan. as_any ( ) . downcast_ref :: < ArrowFlightReadExec > ( ) else {
125128 return Ok ( Transformed :: no ( plan) ) ;
126129 } ;
@@ -137,9 +140,13 @@ impl DistributedPhysicalOptimizerRule {
137140 let mut stage = ExecutionStage :: new ( query_id, * num, distributed. data , inputs) ;
138141 * num += 1 ;
139142
140- if let Some ( partitions_per_task) = self . partitions_per_task {
141- stage = stage. with_maximum_partitions_per_task ( partitions_per_task) ;
142- }
143+ stage = match ( self . partitions_per_task , can_be_divided ( & plan) ?) {
144+ ( Some ( partitions_per_task) , true ) => {
145+ stage. with_maximum_partitions_per_task ( partitions_per_task)
146+ }
147+ ( _, _) => stage,
148+ } ;
149+
143150 stage. depth = depth;
144151
145152 Ok ( stage)
0 commit comments