Commit 6ed053c

Split Distributed step in two:
1. ArrowFlightReadExec assignment
2. Stage creation

1 parent 8067b44 commit 6ed053c
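
The split is visible in the new `optimize` body in src/physical_optimizer.rs below: the rule first injects `ArrowFlightReadExec` nodes at every `RepartitionExec` boundary, then recursively folds the plan into numbered `ExecutionStage`s. A minimal sketch of driving the two phases directly, assuming the `Default` construction that the `#[derive(Debug, Default)]` in this diff provides (the driver function itself is hypothetical):

```rust
use std::sync::Arc;
use datafusion::error::Result;
use datafusion::physical_plan::ExecutionPlan;

// Hypothetical driver; `DistributedPhysicalOptimizerRule`,
// `apply_network_boundaries`, and `distribute_plan` are the items
// introduced in this commit (assumed in scope from the crate).
fn plan_distributed(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    let rule = DistributedPhysicalOptimizerRule::default();
    // Phase 1: insert ArrowFlightReadExec at each RepartitionExec boundary.
    let plan = rule.apply_network_boundaries(plan)?;
    // Phase 2: cut the plan into numbered ExecutionStages at those boundaries.
    let stage = rule.distribute_plan(plan)?;
    Ok(Arc::new(stage))
}
```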

File tree

12 files changed (+250 -290 lines)

src/flight_service/do_get.rs

Lines changed: 0 additions & 4 deletions
@@ -79,7 +79,3 @@ impl ArrowFlightEndpoint {
         ))))
     }
 }
-
-fn invalid_argument<T>(msg: impl Into<String>) -> Result<T, Status> {
-    Err(Status::invalid_argument(msg))
-}
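
The removed `invalid_argument` helper was a thin wrapper over tonic's `Status::invalid_argument` constructor; presumably its last call site went away in this refactor. For reference, the inline equivalent at any remaining call site would be (hypothetical usage, not part of this diff):

```rust
use tonic::Status;

// Inline form of what the deleted helper produced:
// an Err carrying a gRPC INVALID_ARGUMENT status.
fn example() -> Result<(), Status> {
    Err(Status::invalid_argument("missing ticket"))
}
```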

src/physical_optimizer.rs

Lines changed: 84 additions & 164 deletions
@@ -1,9 +1,12 @@
 use std::sync::Arc;
 
+use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec};
+use datafusion::common::tree_node::TreeNodeRecursion;
+use datafusion::error::DataFusionError;
 use datafusion::{
     common::{
         internal_datafusion_err,
-        tree_node::{Transformed, TreeNode, TreeNodeRewriter},
+        tree_node::{Transformed, TreeNode},
     },
     config::ConfigOptions,
     error::Result,
@@ -14,8 +17,6 @@ use datafusion::{
 };
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 
-use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec};
-
 use super::stage::ExecutionStage;
 
 #[derive(Debug, Default)]
@@ -75,11 +76,9 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule {
             displayable(plan.as_ref()).indent(false)
         );
 
-        let mut planner = StagePlanner::new(self.codec.clone(), self.partitions_per_task);
-        plan.rewrite(&mut planner)?;
-        planner
-            .finish()
-            .map(|stage| stage as Arc<dyn ExecutionPlan>)
+        let plan = self.apply_network_boundaries(plan)?;
+        let plan = self.distribute_plan(plan)?;
+        Ok(Arc::new(plan))
     }
 
     fn name(&self) -> &str {
@@ -91,172 +90,93 @@ impl PhysicalOptimizerRule for DistributedPhysicalOptimizerRule {
     }
 }
 
-/// StagePlanner is a TreeNodeRewriter that walks the plan tree and creates
-/// a tree of ExecutionStage nodes that represent discrete stages of execution
-/// that are separated by a data shuffle.
-///
-/// See https://howqueryengineswork.com/13-distributed-query.html for more information
-/// about distributed execution.
-struct StagePlanner {
-    /// used to keep track of the current plan head
-    plan_head: Option<Arc<dyn ExecutionPlan>>,
-    /// Current depth in the plan tree, as we walk the tree
-    depth: usize,
-    /// Input stages collected so far. Each entry is a tuple of (plan tree depth, stage).
-    /// This allows us to keep track of the depth in the plan tree
-    /// where we created the stage. That way when we create a new
-    /// stage, we can tell if it is a peer to the current input stages or
-    /// should be a parent (if its depth is a smaller number)
-    input_stages: Vec<(usize, ExecutionStage)>,
-    /// current stage number
-    stage_counter: usize,
-    /// Optional codec to assist in serializing and deserializing any custom
-    codec: Option<Arc<dyn PhysicalExtensionCodec>>,
-    /// partitions_per_task is used to determine how many tasks to create for each stage
-    partitions_per_task: Option<usize>,
-}
-
-impl StagePlanner {
-    fn new(
-        codec: Option<Arc<dyn PhysicalExtensionCodec>>,
-        partitions_per_task: Option<usize>,
-    ) -> Self {
-        StagePlanner {
-            plan_head: None,
-            depth: 0,
-            input_stages: vec![],
-            stage_counter: 1,
-            codec,
-            partitions_per_task,
-        }
-    }
-
-    fn finish(mut self) -> Result<Arc<ExecutionStage>> {
-        let stage = if self.input_stages.is_empty() {
-            ExecutionStage::new(
-                self.stage_counter,
-                self.plan_head
-                    .take()
-                    .ok_or_else(|| internal_datafusion_err!("No plan head set"))?,
-                vec![],
-            )
-        } else if self.depth < self.input_stages[0].0 {
-            // There is more plan above the last stage we created, so we need to
-            // create a new stage that includes the last plan head
-            ExecutionStage::new(
-                self.stage_counter,
-                self.plan_head
-                    .take()
-                    .ok_or_else(|| internal_datafusion_err!("No plan head set"))?,
-                self.input_stages
-                    .into_iter()
-                    .map(|(_, stage)| Arc::new(stage))
-                    .collect(),
-            )
-        } else {
-            // We have a plan head, and we are at the same depth as the last stage we created,
-            // so we can just return the last stage
-            self.input_stages.last().unwrap().1.clone()
-        };
-
-        // assign the proper tree depth to each stage in the tree
-        fn assign_tree_depth(stage: &ExecutionStage, depth: usize) {
-            stage
-                .depth
-                .store(depth as u64, std::sync::atomic::Ordering::Relaxed);
-            for input in stage.child_stages_iter() {
-                assign_tree_depth(input, depth + 1);
+impl DistributedPhysicalOptimizerRule {
+    pub fn apply_network_boundaries(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        let result = plan.transform_up(|plan| {
+            if plan.as_any().downcast_ref::<RepartitionExec>().is_some() {
+                let child = Arc::clone(plan.children().first().cloned().ok_or(
+                    internal_datafusion_err!("Expected RepartitionExec to have a child"),
+                )?);
+
+                let maybe_isolated_plan = if let Some(ppt) = self.partitions_per_task {
+                    let isolated = Arc::new(PartitionIsolatorExec::new(child, ppt));
+                    plan.with_new_children(vec![isolated])?
+                } else {
+                    plan
+                };
+
+                return Ok(Transformed::yes(Arc::new(
+                    ArrowFlightReadExec::new_single_node(
+                        Arc::clone(&maybe_isolated_plan),
+                        maybe_isolated_plan.output_partitioning().clone(),
+                    ),
+                )));
             }
-        }
-        assign_tree_depth(&stage, 0);
 
-        Ok(Arc::new(stage))
+            Ok(Transformed::no(plan))
+        })?;
+        Ok(result.data)
     }
-}
 
-impl TreeNodeRewriter for StagePlanner {
-    type Node = Arc<dyn ExecutionPlan>;
-
-    fn f_down(&mut self, plan: Self::Node) -> Result<Transformed<Self::Node>> {
-        self.depth += 1;
-        Ok(Transformed::no(plan))
+    pub fn distribute_plan(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+    ) -> Result<ExecutionStage, DataFusionError> {
+        self._distribute_plan(plan, &mut 1, 0)
    }
 
-    fn f_up(&mut self, plan: Self::Node) -> Result<Transformed<Self::Node>> {
-        self.depth -= 1;
-
-        // keep track of where we are
-        self.plan_head = Some(plan.clone());
-
-        // determine if we need to shuffle data, and thus create a new stage
-        // at this shuffle boundary
-        if let Some(repartition_exec) = plan.as_any().downcast_ref::<RepartitionExec>() {
-            // time to create a stage here so include all previous seen stages deeper than us as
-            // our input stages
-            let child_stages = self
-                .input_stages
-                .iter()
-                .rev()
-                .take_while(|(depth, _)| *depth > self.depth)
-                .map(|(_, stage)| stage.clone())
-                .collect::<Vec<_>>();
-
-            self.input_stages.retain(|(depth, _)| *depth <= self.depth);
-
-            let maybe_isolated_plan = if let Some(partitions_per_task) = self.partitions_per_task {
-                let child = repartition_exec
-                    .children()
-                    .first()
-                    .ok_or(internal_datafusion_err!(
-                        "RepartitionExec has no children, cannot create PartitionIsolatorExec"
-                    ))?
-                    .clone()
-                    .clone(); // just clone the Arcs
-                let isolated = Arc::new(PartitionIsolatorExec::new(child, partitions_per_task));
-                plan.clone().with_new_children(vec![isolated])?
-            } else {
-                plan.clone()
+    fn _distribute_plan(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        num: &mut usize,
+        depth: usize,
+    ) -> Result<ExecutionStage, DataFusionError> {
+        let mut inputs = vec![];
+        for reader in find_children::<ArrowFlightReadExec>(&plan) {
+            let child = Arc::clone(reader.children().first().cloned().ok_or(
+                internal_datafusion_err!("Expected ArrowFlightReadExec to have a child"),
+            )?);
+            inputs.push(self._distribute_plan(child, num, depth + 1)?);
+        }
+        let mut input_index = 0;
+        let ready = plan.transform_down(|plan| {
+            let Some(node) = plan.as_any().downcast_ref::<ArrowFlightReadExec>() else {
+                return Ok(Transformed::no(plan));
             };
+            let node = Arc::new(node.to_distributed(inputs[input_index].num)?);
+            input_index += 1;
+            Ok(Transformed::new(node, true, TreeNodeRecursion::Stop))
+        })?;
+        let inputs = inputs.into_iter().map(Arc::new).collect();
+        let mut stage = ExecutionStage::new(*num, ready.data, inputs);
+        *num += 1;
+
+        if let Some(partitions_per_task) = self.partitions_per_task {
+            stage = stage.with_maximum_partitions_per_task(partitions_per_task);
+        }
+        if let Some(codec) = self.codec.as_ref() {
+            stage = stage.with_codec(codec.clone());
+        }
+        stage.depth = depth;
 
-            let mut stage = ExecutionStage::new(
-                self.stage_counter,
-                maybe_isolated_plan,
-                child_stages.into_iter().map(Arc::new).collect(),
-            );
-
-            if let Some(partitions_per_task) = self.partitions_per_task {
-                stage = stage.with_maximum_partitions_per_task(partitions_per_task);
-            }
-            if let Some(codec) = self.codec.as_ref() {
-                stage = stage.with_codec(codec.clone());
-            }
+        Ok(stage)
+    }
+}
 
-            self.input_stages.push((self.depth, stage));
-
-            // As we are walking up the plan tree, we've now put what we've encountered so far
-            // into a stage. We want to replace this plan now with an ArrowFlightReadExec
-            // which will be able to consume from this stage over the network.
-            //
-            // That way as we walk further up the tree and build the next stage, the leaf
-            // node in that plan will be an ArrowFlightReadExec that can read from it.
-            //
-            // Note that we use the original plan's partitioning and schema for ArrowFlightReadExec.
-            // If we divide it up into tasks, then that partition will need to be gathered from
-            // among them
-            let name = format!("Stage {:<3}", self.stage_counter);
-            let read = Arc::new(ArrowFlightReadExec::new(
-                plan.output_partitioning().clone(),
-                plan.schema(),
-                self.stage_counter,
-            ));
-
-            self.stage_counter += 1;
-
-            Ok(Transformed::yes(read as Self::Node))
-        } else {
-            Ok(Transformed::no(plan))
-        }
+fn find_children<T: ExecutionPlan + 'static>(
+    plan: &Arc<dyn ExecutionPlan>,
+) -> Vec<&Arc<dyn ExecutionPlan>> {
+    if plan.as_any().is::<T>() {
+        return vec![plan];
+    }
+    let mut result = vec![];
+    for child in plan.children() {
+        result.extend(find_children::<T>(child));
     }
+    result
 }
 
 #[cfg(test)]
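
For orientation, the new `_distribute_plan` numbers stages bottom-up: it recurses into each `ArrowFlightReadExec` subtree first, threading a shared `&mut usize` counter, so every child stage receives its number before its parent; that is why `to_distributed(inputs[input_index].num)` can already reference a finished child. A minimal sketch of that counter-threading pattern, with hypothetical names (`Node`, `number`) that are not crate API:

```rust
// Toy model of the counter-threading in `_distribute_plan`: children are
// visited (and numbered) before their parent, so leaf stages get the
// smallest numbers and the root stage the largest.
struct Node {
    children: Vec<Node>,
}

fn number(node: &Node, num: &mut usize, depth: usize) -> Vec<(usize, usize)> {
    let mut out = vec![];
    for child in &node.children {
        out.extend(number(child, num, depth + 1));
    }
    out.push((*num, depth)); // (stage number, depth), assigned post-order
    *num += 1;
    out
}

fn main() {
    // Two leaves under one root: the leaves become stages 1 and 2, the root
    // becomes stage 3, mirroring the `self._distribute_plan(plan, &mut 1, 0)` entry point.
    let plan = Node {
        children: vec![Node { children: vec![] }, Node { children: vec![] }],
    };
    assert_eq!(number(&plan, &mut 1, 0), vec![(1, 1), (2, 1), (3, 0)]);
}
```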
