@@ -10,22 +10,97 @@ use datafusion::{
     config::ConfigOptions,
     datasource::physical_plan::{FileScanConfig, FileSource},
     error::Result,
+    physical_optimizer::PhysicalOptimizerRule,
     physical_plan::{
         displayable, execution_plan::need_data_exchange, ExecutionPlan, ExecutionPlanProperties,
     },
 };
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;

 use crate::ArrowFlightReadExec;

 use super::stage::ExecutionStage;

+#[derive(Debug, Default)]
+pub struct DistributedPhysicalOptimizerRule {
+    /// Optional codec to assist in serializing and deserializing any custom
+    /// ExecutionPlan nodes
+    codec: Option<Arc<dyn PhysicalExtensionCodec>>,
+    /// Maximum number of partitions per task. This is used to determine how many
+    /// tasks to create for each stage.
+    partitions_per_task: Option<usize>,
+}
+
+impl DistributedPhysicalOptimizerRule {
+    pub fn new() -> Self {
+        DistributedPhysicalOptimizerRule {
+            codec: None,
+            partitions_per_task: None,
+        }
+    }
+
+    /// Set a codec to use to assist in serializing and deserializing
+    /// custom ExecutionPlan nodes.
+    pub fn with_codec(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
+        self.codec = Some(codec);
+        self
+    }
+
+    /// Set the maximum number of partitions per task. This is used to determine how many
+    /// tasks to create for each stage.
+    ///
+    /// If a stage holds a plan with 10 partitions, and this is set to 3,
+    /// then the stage will be split into 4 tasks:
+    /// - Task 1: partitions 0, 1, 2
+    /// - Task 2: partitions 3, 4, 5
+    /// - Task 3: partitions 6, 7, 8
+    /// - Task 4: partition 9
+    ///
+    /// Each task will be executed on a separate host.
+    pub fn with_maximum_partitions_per_task(mut self, partitions_per_task: usize) -> Self {
+        self.partitions_per_task = Some(partitions_per_task);
+        self
+    }
+}
+
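// A minimal sketch of the split that `with_maximum_partitions_per_task`
// performs, per its doc comment above: partitions are chunked into groups of
// at most `per_task`. `split_partitions` is a hypothetical helper for
// illustration only, not part of this crate.
fn split_partitions(total_partitions: usize, per_task: usize) -> Vec<Vec<usize>> {
    (0..total_partitions)
        .collect::<Vec<_>>()
        // chunks() leaves a short final chunk, e.g. the lone partition 9 above
        .chunks(per_task)
        .map(|chunk| chunk.to_vec())
        .collect()
}
// split_partitions(10, 3) == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]: 4 tasks.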
+impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        _config: &ConfigOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // We can only optimize plans that are not already distributed
+        if plan.as_any().is::<ExecutionStage>() {
+            return Ok(plan);
+        }
+        println!(
+            "DistributedPhysicalOptimizerRule: optimizing plan: {}",
+            displayable(plan.as_ref()).indent(false)
+        );
+
+        let mut planner = StagePlanner::new(self.codec.clone(), self.partitions_per_task);
+        plan.rewrite(&mut planner)?;
+        planner
+            .finish()
+            .map(|stage| stage as Arc<dyn ExecutionPlan>)
+    }
+
+    fn name(&self) -> &str {
+        "DistributedPhysicalOptimizer"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
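// A usage sketch, assuming DataFusion's `SessionStateBuilder` and its
// `with_physical_optimizer_rule` method; queries planned through the returned
// context are then rewritten by `optimize` above into a tree of stages.
fn make_distributed_context() -> datafusion::prelude::SessionContext {
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    // Cap each task at 3 partitions; no custom codec is configured here.
    let rule = DistributedPhysicalOptimizerRule::new().with_maximum_partitions_per_task(3);
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_physical_optimizer_rule(Arc::new(rule))
        .build();
    SessionContext::new_with_state(state)
}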
 /// StagePlanner is a TreeNodeRewriter that walks the plan tree and creates
 /// a tree of ExecutionStage nodes that represent discrete stages of execution
 /// that are separated by a data shuffle.
 ///
 /// See https://howqueryengineswork.com/13-distributed-query.html for more information
 /// about distributed execution.
-pub struct StagePlanner {
+struct StagePlanner {
     /// used to keep track of the current plan head
     plan_head: Option<Arc<dyn ExecutionPlan>>,
     /// Current depth in the plan tree, as we walk the tree
@@ -38,48 +113,39 @@ pub struct StagePlanner {
     input_stages: Vec<(usize, Arc<ExecutionStage>)>,
     /// current stage number
     stage_counter: usize,
-}
-
-/// Create an ExecutionStage from a plan.
-///
-/// The resulting ExecutionStage cannot be executed directly, but is used as input,
-/// together with a vec of Worker Addresses, to further break up the stages into [`ExecutionTask`]s
-/// Which are ultimately what facilitates distributed execution.
-impl TryFrom<Arc<dyn ExecutionPlan>> for ExecutionStage {
-    type Error = DataFusionError;
-    fn try_from(plan: Arc<dyn ExecutionPlan>) -> Result<Self> {
-        let mut planner = StagePlanner {
-            plan_head: Some(plan.clone()),
-            depth: 0,
-            input_stages: vec![],
-            stage_counter: 1,
-        };
-
-        plan.rewrite(&mut planner)?;
-        planner.finish().and_then(|arc_stage| {
-            Arc::into_inner(arc_stage).ok_or(internal_datafusion_err!(
-                "Failed to convert Arc<ExecutionStage> to ExecutionStage"
-            ))
-        })
-    }
+    /// Optional codec to assist in serializing and deserializing any custom
+    /// ExecutionPlan nodes
+    codec: Option<Arc<dyn PhysicalExtensionCodec>>,
+    /// partitions_per_task is used to determine how many tasks to create for each stage
+    partitions_per_task: Option<usize>,
 }

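// For intuition, the shape StagePlanner produces: each shuffle boundary in the
// plan (the need_data_exchange check imported above) closes the stage built so
// far and makes it an input of the stage above it. Node names here are
// illustrative, not taken from a real plan:
//
//   AggregateExec (final)                 Stage 2: AggregateExec (final)
//     RepartitionExec (shuffle)    ==>             ArrowFlightReadExec
//       AggregateExec (partial)            Stage 1: AggregateExec (partial)
//         DataSourceExec                            DataSourceExec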
 impl StagePlanner {
-    pub fn new() -> Self {
+    fn new(
+        codec: Option<Arc<dyn PhysicalExtensionCodec>>,
+        partitions_per_task: Option<usize>,
+    ) -> Self {
         StagePlanner {
             plan_head: None,
             depth: 0,
             input_stages: vec![],
-            stage_counter: 0,
+            stage_counter: 1,
+            codec,
+            partitions_per_task,
         }
     }

-    pub fn finish(mut self) -> Result<Arc<ExecutionStage>> {
+    fn finish(mut self) -> Result<Arc<ExecutionStage>> {
         if self.input_stages.is_empty() {
-            return internal_err!("No input stages found, did you forget to call rewrite()?");
-        }
-
-        if self.depth < self.input_stages[0].0 {
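+            // No stages were created while walking the plan (no shuffle
+            // boundaries), so wrap the whole plan in a single stage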
+            Ok(Arc::new(ExecutionStage::new(
+                self.stage_counter,
+                self.plan_head
+                    .take()
+                    .ok_or_else(|| internal_datafusion_err!("No plan head set"))?,
+                vec![],
+            )))
+        } else if self.depth < self.input_stages[0].0 {
+            // There is more plan above the last stage we created, so we need to
+            // create a new stage that includes the last plan head
             Ok(Arc::new(ExecutionStage::new(
                 self.stage_counter,
                 self.plan_head
@@ -88,10 +154,12 @@ impl StagePlanner {
                 self.input_stages
                     .iter()
                     .map(|(_, stage)| stage.clone())
-                    .collect::<Vec<_>>(),
+                    .collect(),
             )))
         } else {
-            Ok(self.input_stages.remove(0).1)
+            // We have a plan head, and we are at the same depth as the last stage
+            // we created, so we can just return the last stage
+            Ok(self.input_stages.last().unwrap().1.clone())
         }
     }
 }
@@ -142,13 +210,16 @@ impl TreeNodeRewriter for StagePlanner {

         self.input_stages.retain(|(depth, _)| *depth <= self.depth);

-        let stage = Arc::new(ExecutionStage::new(
-            self.stage_counter,
-            plan.clone(),
-            child_stages,
-        ));
+        let mut stage = ExecutionStage::new(self.stage_counter, plan.clone(), child_stages);
+
+        if let Some(partitions_per_task) = self.partitions_per_task {
+            stage = stage.with_maximum_partitions_per_task(partitions_per_task);
+        }
+        if let Some(codec) = self.codec.as_ref() {
+            stage = stage.with_codec(codec.clone());
+        }

-        self.input_stages.push((self.depth, stage));
+        self.input_stages.push((self.depth, Arc::new(stage)));

         // As we are walking up the plan tree, we've now put what we've encountered so far
         // into a stage. We want to replace this plan now with an ArrowFlightReadExec