Commit 84243d3
Apply TaskEstimator on all nodes (#271)
* Split channel resolver in two
* Simplify WorkerResolverExtension and ChannelResolverExtension
* Add default builder to ArrowFlightEndpoint
* Add some docs
* Listen to clippy
* Split get_flight_client_for_url in two
* Fix conflicts
* Remove unnecessary channel resolver
* Improve WorkerResolver docs
* Use one ChannelResolver per runtime
* Improve error reporting on client connection failure
* Add a from_session_builder method for constructing an InMemoryChannelResolver
* Add ChannelResolver and WorkerResolver default implementations for Arcs
* Make TPC-DS tests use DataFusion test dataset
* Remove non-working in-memory option from benchmarks
* Remove unnecessary utils folder
* Refactor benchmark folder
* Rename to prepare_tpch.rs
* Adapt benchmarks for TPC-DS
* Update benchmarks README.md
* Fix conflicts
* Use default session state builder
* Add ChildrenIsolatorUnionExec
* Add proto serde for ChildrenIsolatorUnionExec
* Wire up ChildrenIsolatorUnionExec to planner
* Add integration tests for distributed UNIONs
* Skip query 72 in TPC-DS benchmarks
* Allow setting children isolator unions
* Allow passing multiple queries in benchmarks
* Extend TaskEstimator API to also be applicable to intermediate nodes
1 parent a4b0b5a commit 84243d3
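
The headline change renames tasks_for_leaf_node to task_estimation and consults it on every plan node, not only on leaves. A minimal sketch of what a custom estimator looks like under the new API; the CapSorts type, its sort-capping policy, and the datafusion_distributed crate path are illustrative assumptions, not part of this commit:

    use std::sync::Arc;

    use datafusion::config::ConfigOptions;
    use datafusion::physical_plan::sorts::sort::SortExec;
    use datafusion::physical_plan::ExecutionPlan;
    // Assumed crate name and re-exports; the diff below only shows the traits.
    use datafusion_distributed::{TaskEstimation, TaskEstimator};

    struct CapSorts;

    impl TaskEstimator for CapSorts {
        // Called on every node after this commit, not only on leaves.
        fn task_estimation(
            &self,
            plan: &Arc<dyn ExecutionPlan>,
            _cfg: &ConfigOptions,
        ) -> Option<TaskEstimation> {
            // Hypothetical policy: cap any stage containing a sort at 2 tasks.
            // Returning None defers to other estimators or the defaults.
            plan.as_any()
                .downcast_ref::<SortExec>()
                .map(|_| TaskEstimation::maximum(2))
        }

        fn scale_up_leaf_node(
            &self,
            _plan: &Arc<dyn ExecutionPlan>,
            _task_count: usize,
            _cfg: &ConfigOptions,
        ) -> Option<Arc<dyn ExecutionPlan>> {
            None // leaf-node rewriting is unchanged by this commit
        }
    }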

File tree: 2 files changed, +178 −57 lines

src/distributed_planner/plan_annotator.rs

Lines changed: 110 additions & 8 deletions
@@ -138,6 +138,7 @@ fn _annotate_plan(
 ) -> Result<AnnotatedPlan, DataFusionError> {
     use TaskCountAnnotation::*;
     let d_cfg = DistributedConfig::from_config_options(cfg)?;
+    let estimator = &d_cfg.__private_task_estimator;
     let n_workers = d_cfg.__private_worker_resolver.0.get_urls()?.len().max(1);

     let annotated_children = plan
@@ -150,8 +151,7 @@
         // This is a leaf node, maybe a DataSourceExec, or maybe something else custom from the
         // user. We need to estimate how many tasks are needed for this leaf node, and we'll take
         // this decision into account when deciding how many tasks will be actually used.
-        let estimator = &d_cfg.__private_task_estimator;
-        if let Some(estimate) = estimator.tasks_for_leaf_node(&plan, cfg) {
+        if let Some(estimate) = estimator.task_estimation(&plan, cfg) {
             return Ok(AnnotatedPlan {
                 plan,
                 children: Vec::new(),
@@ -170,7 +170,9 @@
         }
     }

-    let mut task_count = Desired(1);
+    let mut task_count = estimator
+        .task_estimation(&plan, cfg)
+        .map_or(Desired(1), |v| v.task_count);
     if d_cfg.children_isolator_unions && plan.as_any().is::<UnionExec>() {
         // Unions have the chance to decide how many tasks they should run on. If there's a union
         // with a bunch of children, the user might want to increase parallelism and increase the
@@ -341,11 +343,11 @@ mod tests {
     use super::*;
     use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver;
     use crate::test_utils::parquet::register_parquet_tables;
-    use crate::{DistributedExt, assert_snapshot};
+    use crate::{DistributedExt, TaskEstimation, assert_snapshot};
     use datafusion::execution::SessionStateBuilder;
+    use datafusion::physical_plan::filter::FilterExec;
     use datafusion::prelude::{SessionConfig, SessionContext};
     use itertools::Itertools;
-
     /* schema for the "weather" table

     MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
@@ -584,16 +586,116 @@
         ")
     }

+    #[tokio::test]
+    async fn test_intermediate_task_estimator() {
+        let query = r#"
+            SELECT DISTINCT "RainToday" FROM weather
+        "#;
+        let annotated = sql_to_annotated_with_estimator(query, |_: &RepartitionExec| {
+            Some(TaskEstimation::maximum(1))
+        })
+        .await;
+        assert_snapshot!(annotated, @r"
+        AggregateExec: task_count=Desired(1)
+          CoalesceBatchesExec: task_count=Desired(1), required_network_boundary=Shuffle
+            RepartitionExec: task_count=Maximum(1)
+              RepartitionExec: task_count=Maximum(1)
+                AggregateExec: task_count=Maximum(1)
+                  DataSourceExec: task_count=Maximum(1)
+        ")
+    }
+
+    #[tokio::test]
+    async fn test_union_all_limited_by_intermediate_estimator() {
+        let query = r#"
+            SELECT "MinTemp" FROM weather WHERE "RainToday" = 'yes'
+            UNION ALL
+            SELECT "MaxTemp" FROM weather WHERE "RainToday" = 'no'
+        "#;
+        let annotated = sql_to_annotated_with_estimator(query, |_: &FilterExec| {
+            Some(TaskEstimation::maximum(1))
+        })
+        .await;
+        assert_snapshot!(annotated, @r"
+        ChildrenIsolatorUnionExec: task_count=Desired(2)
+          CoalesceBatchesExec: task_count=Maximum(1)
+            FilterExec: task_count=Maximum(1)
+              RepartitionExec: task_count=Maximum(1)
+                DataSourceExec: task_count=Maximum(1)
+          ProjectionExec: task_count=Maximum(1)
+            CoalesceBatchesExec: task_count=Maximum(1)
+              FilterExec: task_count=Maximum(1)
+                RepartitionExec: task_count=Maximum(1)
+                  DataSourceExec: task_count=Maximum(1)
+        ")
+    }
+
+    #[allow(clippy::type_complexity)]
+    struct CallbackEstimator {
+        f: Arc<dyn Fn(&(dyn ExecutionPlan)) -> Option<TaskEstimation> + Send + Sync>,
+    }
+
+    impl CallbackEstimator {
+        fn new<T: ExecutionPlan + 'static>(
+            f: impl Fn(&T) -> Option<TaskEstimation> + Send + Sync + 'static,
+        ) -> Self {
+            let f = Arc::new(move |plan: &dyn ExecutionPlan| -> Option<TaskEstimation> {
+                if let Some(plan) = plan.as_any().downcast_ref::<T>() {
+                    f(plan)
+                } else {
+                    None
+                }
+            });
+            Self { f }
+        }
+    }
+
+    impl TaskEstimator for CallbackEstimator {
+        fn task_estimation(
+            &self,
+            plan: &Arc<dyn ExecutionPlan>,
+            _: &ConfigOptions,
+        ) -> Option<TaskEstimation> {
+            (self.f)(plan.as_ref())
+        }
+
+        fn scale_up_leaf_node(
+            &self,
+            _: &Arc<dyn ExecutionPlan>,
+            _: usize,
+            _: &ConfigOptions,
+        ) -> Option<Arc<dyn ExecutionPlan>> {
+            None
+        }
+    }
+
     async fn sql_to_annotated(query: &str) -> String {
+        sql_to_annotated_with_options(query, move |b| b).await
+    }
+
+    async fn sql_to_annotated_with_estimator<T: ExecutionPlan + Send + Sync + 'static>(
+        query: &str,
+        estimator: impl Fn(&T) -> Option<TaskEstimation> + Send + Sync + 'static,
+    ) -> String {
+        sql_to_annotated_with_options(query, move |b| {
+            b.with_distributed_task_estimator(CallbackEstimator::new(estimator))
+        })
+        .await
+    }
+
+    async fn sql_to_annotated_with_options(
+        query: &str,
+        f: impl FnOnce(SessionStateBuilder) -> SessionStateBuilder,
+    ) -> String {
         let config = SessionConfig::new()
             .with_target_partitions(4)
             .with_information_schema(true);

-        let state = SessionStateBuilder::new()
+        let state = f(SessionStateBuilder::new()
             .with_default_features()
             .with_config(config)
-            .with_distributed_worker_resolver(InMemoryWorkerResolver::new(4))
-            .build();
+            .with_distributed_worker_resolver(InMemoryWorkerResolver::new(4)))
+        .build();

         let ctx = SessionContext::new_with_state(state);
         let mut queries = query.split(";").collect_vec();
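
The two tests above exercise the precedence rules between task-count annotations: a Maximum cap always beats a Desired hint, and the smallest Maximum in a stage wins. A minimal sketch of that resolution logic, where the combine helper and the local TaskCountAnnotation mirror are illustrative stand-ins rather than code from this commit:

    // Local mirror of the crate's TaskCountAnnotation, for illustration only.
    #[derive(Clone, Copy, Debug, PartialEq)]
    enum TaskCountAnnotation {
        Desired(usize),
        Maximum(usize),
    }

    // Hypothetical helper expressing the precedence rules documented in
    // task_estimator.rs below.
    fn combine(a: TaskCountAnnotation, b: TaskCountAnnotation) -> TaskCountAnnotation {
        use TaskCountAnnotation::*;
        match (a, b) {
            // The smaller of two hard caps has preference.
            (Maximum(n), Maximum(m)) => Maximum(n.min(m)),
            // A hard cap overrides any soft hint, whatever its value.
            (Maximum(n), Desired(_)) | (Desired(_), Maximum(n)) => Maximum(n),
            // Between soft hints, the larger desired task count wins.
            (Desired(n), Desired(m)) => Desired(n.max(m)),
        }
    }

    fn main() {
        use TaskCountAnnotation::*;
        // Matches the snapshots above: the estimator's Maximum(1) caps the
        // stage even though the planner would otherwise desire more tasks.
        assert_eq!(combine(Desired(4), Maximum(1)), Maximum(1));
    }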

src/distributed_planner/task_estimator.rs

Lines changed: 68 additions & 49 deletions
@@ -5,6 +5,7 @@ use datafusion::config::ConfigOptions;
 use datafusion::datasource::physical_plan::FileScanConfig;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionConfig;
+use delegate::delegate;
 use std::collections::HashSet;
 use std::fmt::Debug;
 use std::sync::Arc;
@@ -58,6 +59,34 @@ pub struct TaskEstimation {
     pub task_count: TaskCountAnnotation,
 }

+impl TaskEstimation {
+    /// Tells the distributed planner that the evaluated stage can have **at maximum** the provided
+    /// number of tasks, setting a hard upper limit.
+    ///
+    /// Returning `TaskEstimation::maximum(1)` tells the distributed planner that the evaluated
+    /// stage cannot be distributed.
+    ///
+    /// Even if a `TaskEstimation::maximum(N)` is provided, any other node in the same stage
+    /// providing a value of `TaskEstimation::maximum(M)` where `M` < `N` will have preference.
+    pub fn maximum(value: usize) -> Self {
+        TaskEstimation {
+            task_count: TaskCountAnnotation::Maximum(value),
+        }
+    }
+
+    /// Tells the distributed planner that the evaluated stage can **optimally** have the provided
+    /// number of tasks, setting a soft task count hint that can be overridden by others.
+    ///
+    /// The provided `TaskEstimation::desired(N)` can be overridden by:
+    /// - Other nodes providing a `TaskEstimation::desired(M)` where `M` > `N`.
+    /// - Any other node providing a `TaskEstimation::maximum(M)` where `M` can be anything.
+    pub fn desired(value: usize) -> Self {
+        TaskEstimation {
+            task_count: TaskCountAnnotation::Desired(value),
+        }
+    }
+}
+
 /// Given a leaf node, provides an estimation about how many tasks should be used in the
 /// stage containing it, and if the leaf node should be replaced by some other.
 ///
@@ -66,14 +95,19 @@ pub struct TaskEstimation {
 /// count calculated based on whether lower stages are reducing the cardinality of the data
 /// or increasing it.
 pub trait TaskEstimator {
-    /// Function applied to leaf nodes that returns a [TaskEstimation] hinting how many
-    /// tasks should be used in the [Stage] containing that leaf node.
+    /// Function applied to each node that returns a [TaskEstimation] hinting how many
+    /// tasks should be used in the [Stage] containing that node.
+    ///
+    /// All the [TaskEstimator] registered in the session will be applied to the node
+    /// until one returns an estimation.
+    ///
     ///
-    /// All the [TaskEstimator] registered in the session will be applied to the leaf node
-    /// until one returns an estimation. If no estimation is return from any of the
-    /// [TaskEstimator]s, then `Maximum(1)` is returned, hinting the distributed planner to not
-    /// distribute the stage containing that node.
-    fn tasks_for_leaf_node(
+    /// If no estimation is returned from any of the registered [TaskEstimator]s, then:
+    /// - If the node is a leaf node, `Maximum(1)` is assumed, hinting the distributed planner
+    ///   that the leaf node cannot be distributed across tasks.
+    /// - If the node is a normal node in the plan, then the maximum task count from its children
+    ///   is inherited.
+    fn task_estimation(
         &self,
         plan: &Arc<dyn ExecutionPlan>,
         cfg: &ConfigOptions,
@@ -91,14 +125,18 @@ pub trait TaskEstimator {
 }

 impl TaskEstimator for usize {
-    fn tasks_for_leaf_node(
+    fn task_estimation(
         &self,
-        _: &Arc<dyn ExecutionPlan>,
+        inputs: &Arc<dyn ExecutionPlan>,
         _: &ConfigOptions,
     ) -> Option<TaskEstimation> {
-        Some(TaskEstimation {
-            task_count: TaskCountAnnotation::Desired(*self),
-        })
+        if inputs.children().is_empty() {
+            Some(TaskEstimation {
+                task_count: TaskCountAnnotation::Desired(*self),
+            })
+        } else {
+            None
+        }
     }

     fn scale_up_leaf_node(
@@ -112,40 +150,20 @@ impl TaskEstimator for usize {
 }

 impl TaskEstimator for Arc<dyn TaskEstimator> {
-    fn tasks_for_leaf_node(
-        &self,
-        plan: &Arc<dyn ExecutionPlan>,
-        cfg: &ConfigOptions,
-    ) -> Option<TaskEstimation> {
-        self.as_ref().tasks_for_leaf_node(plan, cfg)
-    }
-
-    fn scale_up_leaf_node(
-        &self,
-        plan: &Arc<dyn ExecutionPlan>,
-        task_count: usize,
-        cfg: &ConfigOptions,
-    ) -> Option<Arc<dyn ExecutionPlan>> {
-        self.as_ref().scale_up_leaf_node(plan, task_count, cfg)
+    delegate! {
+        to self.as_ref() {
+            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
+            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
+        }
     }
 }

 impl TaskEstimator for Arc<dyn TaskEstimator + Send + Sync> {
-    fn tasks_for_leaf_node(
-        &self,
-        plan: &Arc<dyn ExecutionPlan>,
-        cfg: &ConfigOptions,
-    ) -> Option<TaskEstimation> {
-        self.as_ref().tasks_for_leaf_node(plan, cfg)
-    }
-
-    fn scale_up_leaf_node(
-        &self,
-        plan: &Arc<dyn ExecutionPlan>,
-        task_count: usize,
-        cfg: &ConfigOptions,
-    ) -> Option<Arc<dyn ExecutionPlan>> {
-        self.as_ref().scale_up_leaf_node(plan, task_count, cfg)
+    delegate! {
+        to self.as_ref() {
+            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
+            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
+        }
     }
 }

@@ -177,15 +195,16 @@ pub(crate) fn set_distributed_task_estimator(
 struct FileScanConfigTaskEstimator;

 impl TaskEstimator for FileScanConfigTaskEstimator {
-    fn tasks_for_leaf_node(
+    fn task_estimation(
         &self,
         plan: &Arc<dyn ExecutionPlan>,
         cfg: &ConfigOptions,
     ) -> Option<TaskEstimation> {
-        let d_cfg = cfg.extensions.get::<DistributedConfig>()?;
         let dse: &DataSourceExec = plan.as_any().downcast_ref()?;
         let file_scan: &FileScanConfig = dse.data_source().as_any().downcast_ref()?;

+        let d_cfg = cfg.extensions.get::<DistributedConfig>()?;
+
         // Count how many distinct files we have in the FileScanConfig. Each file in each
         // file group is a PartitionedFile rather than a full file, so it's possible that
         // many entries refer to different chunks of the same physical file. By keeping a
@@ -244,21 +263,21 @@ pub(crate) struct CombinedTaskEstimator {
 }

 impl TaskEstimator for CombinedTaskEstimator {
-    fn tasks_for_leaf_node(
+    fn task_estimation(
         &self,
         plan: &Arc<dyn ExecutionPlan>,
         cfg: &ConfigOptions,
     ) -> Option<TaskEstimation> {
         for estimator in &self.user_provided {
-            if let Some(result) = estimator.tasks_for_leaf_node(plan, cfg) {
+            if let Some(result) = estimator.task_estimation(plan, cfg) {
                 return Some(result);
             }
         }
         // We want to execute the default estimators last so that the user-provided ones have
         // a chance of providing an estimation.
         // If none of the user-provided returned an estimation, the default ones are used.
         for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
-            if let Some(result) = default_estimator.tasks_for_leaf_node(plan, cfg) {
+            if let Some(result) = default_estimator.task_estimation(plan, cfg) {
                 return Some(result);
             }
         }
@@ -348,7 +367,7 @@ mod tests {
             ..Default::default()
         };
         cfg.extensions.insert(f(d_cfg));
-        self.tasks_for_leaf_node(&node, &cfg)
+        self.task_estimation(&node, &cfg)
             .unwrap()
             .task_count
             .as_usize()
@@ -370,7 +389,7 @@
     }

     impl<F: Fn(&Arc<dyn ExecutionPlan>, &ConfigOptions) -> Option<TaskEstimation>> TaskEstimator for F {
-        fn tasks_for_leaf_node(
+        fn task_estimation(
             &self,
             plan: &Arc<dyn ExecutionPlan>,
             cfg: &ConfigOptions,
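
For completeness, a minimal sketch of wiring an estimator into a session. It assumes the DistributedExt extension trait the tests above import as crate::DistributedExt is exported from a crate named datafusion_distributed; the plain usize estimator is the implementation shown in this file, which after this commit desires that many tasks only for leaf nodes and defers on intermediate ones:

    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::SessionContext;
    // Assumed crate name for the extension trait used in the tests above.
    use datafusion_distributed::DistributedExt;

    fn build_ctx() -> SessionContext {
        let state = SessionStateBuilder::new()
            .with_default_features()
            // `usize` implements TaskEstimator: Desired(4) for leaf nodes,
            // None (defer) for intermediate nodes.
            .with_distributed_task_estimator(4usize)
            .build();
        SessionContext::new_with_state(state)
    }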
