Skip to content

Commit b1c72c3

Browse files
committed
Simplify physical_optimizer
1 parent 6ed053c commit b1c72c3

File tree

1 file changed

+12
-26
lines changed

1 file changed

+12
-26
lines changed

src/physical_optimizer.rs

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -125,33 +125,32 @@ impl DistributedPhysicalOptimizerRule {
125125
&self,
126126
plan: Arc<dyn ExecutionPlan>,
127127
) -> Result<ExecutionStage, DataFusionError> {
128-
self._distribute_plan(plan, &mut 1, 0)
128+
self._distribute_plan_inner(plan, &mut 1, 0)
129129
}
130130

131-
fn _distribute_plan(
131+
fn _distribute_plan_inner(
132132
&self,
133133
plan: Arc<dyn ExecutionPlan>,
134134
num: &mut usize,
135135
depth: usize,
136136
) -> Result<ExecutionStage, DataFusionError> {
137137
let mut inputs = vec![];
138-
for reader in find_children::<ArrowFlightReadExec>(&plan) {
139-
let child = Arc::clone(reader.children().first().cloned().ok_or(
140-
internal_datafusion_err!("Expected ArrowFlightExecRead to have a child"),
141-
)?);
142-
inputs.push(self._distribute_plan(child, num, depth + 1)?);
143-
}
144-
let mut input_index = 0;
145-
let ready = plan.transform_down(|plan| {
138+
139+
let distributed = plan.transform_down(|plan| {
146140
let Some(node) = plan.as_any().downcast_ref::<ArrowFlightReadExec>() else {
147141
return Ok(Transformed::no(plan));
148142
};
149-
let node = Arc::new(node.to_distributed(inputs[input_index].num)?);
150-
input_index += 1;
143+
let child = Arc::clone(node.children().first().cloned().ok_or(
144+
internal_datafusion_err!("Expected ArrowFlightExecRead to have a child"),
145+
)?);
146+
let stage = self._distribute_plan_inner(child, num, depth + 1)?;
147+
let node = Arc::new(node.to_distributed(stage.num)?);
148+
inputs.push(stage);
151149
Ok(Transformed::new(node, true, TreeNodeRecursion::Stop))
152150
})?;
151+
153152
let inputs = inputs.into_iter().map(Arc::new).collect();
154-
let mut stage = ExecutionStage::new(*num, ready.data, inputs);
153+
let mut stage = ExecutionStage::new(*num, distributed.data, inputs);
155154
*num += 1;
156155

157156
if let Some(partitions_per_task) = self.partitions_per_task {
@@ -166,19 +165,6 @@ impl DistributedPhysicalOptimizerRule {
166165
}
167166
}
168167

169-
fn find_children<T: ExecutionPlan + 'static>(
170-
plan: &Arc<dyn ExecutionPlan>,
171-
) -> Vec<&Arc<dyn ExecutionPlan>> {
172-
if plan.as_any().is::<T>() {
173-
return vec![plan];
174-
}
175-
let mut result = vec![];
176-
for child in plan.children() {
177-
result.extend(find_children::<T>(child));
178-
}
179-
result
180-
}
181-
182168
#[cfg(test)]
183169
mod tests {
184170
use crate::assert_snapshot;

0 commit comments

Comments
 (0)