Skip to content

Commit 9f6fdfc

Browse files
committed
Move back to with_distributed_channel_resolver and let users inject either DistributedPhysicalOptimizerRule or their own rules
1 parent 2832861 commit 9f6fdfc

File tree

10 files changed

+62
-50
lines changed

10 files changed

+62
-50
lines changed

benchmarks/src/tpch/run.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ use datafusion_distributed::test_utils::localhost::{
4646
LocalHostChannelResolver, spawn_flight_service,
4747
};
4848
use datafusion_distributed::{
49-
DistributedExt, DistributedSessionBuilder, DistributedSessionBuilderContext, NetworkBoundaryExt,
49+
DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilder,
50+
DistributedSessionBuilderContext, NetworkBoundaryExt,
5051
};
5152
use log::info;
5253
use std::fs;
@@ -138,7 +139,8 @@ impl DistributedSessionBuilder for RunOpt {
138139
.with_default_features()
139140
.with_config(config)
140141
.with_distributed_user_codec(InMemoryCacheExecCodec)
141-
.with_distributed_execution(LocalHostChannelResolver::new(self.workers.clone()))
142+
.with_distributed_channel_resolver(LocalHostChannelResolver::new(self.workers.clone()))
143+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
142144
.with_distributed_option_extension_from_headers::<WarmingUpMarker>(&ctx.headers)?
143145
.with_distributed_files_per_task(
144146
self.files_per_task.unwrap_or(get_available_parallelism()),

examples/in_memory_cluster.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ use datafusion::physical_plan::displayable;
77
use datafusion::prelude::{ParquetReadOptions, SessionContext};
88
use datafusion_distributed::{
99
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
10-
DistributedSessionBuilderContext, create_flight_client,
10+
DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, create_flight_client,
1111
};
1212
use futures::TryStreamExt;
1313
use hyper_util::rt::TokioIo;
1414
use std::error::Error;
15+
use std::sync::Arc;
1516
use structopt::StructOpt;
1617
use tonic::transport::{Endpoint, Server};
1718

@@ -34,7 +35,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
3435

3536
let state = SessionStateBuilder::new()
3637
.with_default_features()
37-
.with_distributed_execution(InMemoryChannelResolver::new())
38+
.with_distributed_channel_resolver(InMemoryChannelResolver::new())
39+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
3840
.build();
3941

4042
let ctx = SessionContext::from(state);
@@ -97,7 +99,7 @@ impl InMemoryChannelResolver {
9799
async move {
98100
let builder = SessionStateBuilder::new()
99101
.with_default_features()
100-
.with_distributed_execution(this)
102+
.with_distributed_channel_resolver(this)
101103
.with_runtime_env(ctx.runtime_env.clone());
102104
Ok(builder.build())
103105
}

examples/localhost_run.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ use datafusion::common::DataFusionError;
66
use datafusion::execution::SessionStateBuilder;
77
use datafusion::physical_plan::displayable;
88
use datafusion::prelude::{ParquetReadOptions, SessionContext};
9-
use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt};
9+
use datafusion_distributed::{
10+
BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule,
11+
};
1012
use futures::TryStreamExt;
1113
use std::error::Error;
14+
use std::sync::Arc;
1215
use structopt::StructOpt;
1316
use tonic::transport::Channel;
1417
use url::Url;
@@ -38,7 +41,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
3841

3942
let state = SessionStateBuilder::new()
4043
.with_default_features()
41-
.with_distributed_execution(localhost_resolver)
44+
.with_distributed_channel_resolver(localhost_resolver)
45+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
4246
.build();
4347

4448
let ctx = SessionContext::from(state);

examples/localhost_worker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
3838
async move {
3939
Ok(SessionStateBuilder::new()
4040
.with_runtime_env(ctx.runtime_env)
41-
.with_distributed_execution(local_host_resolver)
41+
.with_distributed_channel_resolver(local_host_resolver)
4242
.with_default_features()
4343
.build())
4444
}

src/distributed_ext.rs

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::config_extension_ext::{
44
};
55
use crate::distributed_planner::set_distributed_task_estimator;
66
use crate::protobuf::{set_distributed_user_codec, set_distributed_user_codec_arc};
7-
use crate::{ChannelResolver, DistributedConfig, DistributedPhysicalOptimizerRule, TaskEstimator};
7+
use crate::{ChannelResolver, DistributedConfig, TaskEstimator};
88
use datafusion::common::DataFusionError;
99
use datafusion::config::ConfigExtension;
1010
use datafusion::execution::SessionStateBuilder;
@@ -178,14 +178,8 @@ pub trait DistributedExt: Sized {
178178
/// Same as [DistributedExt::set_distributed_user_codec] but with a dynamic argument.
179179
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
180180

181-
/// Enables distributed execution. For this, several things happen:
182-
///
183-
/// - Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker
184-
/// nodes. When running in distributed mode, setting a [ChannelResolver] is required.
185-
/// - Injects a [DistributedPhysicalOptimizerRule] rule that will inject network boundaries
186-
/// in the plan and will break it down into stages.
187-
/// - Injects a [DistributedConfig] object with configuration about the amount of tasks that
188-
/// should be spawned while distributing the queries.
181+
/// Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker
182+
/// nodes. When running in distributed mode, setting a [ChannelResolver] is required.
189183
///
190184
/// Example:
191185
///
@@ -196,7 +190,8 @@ pub trait DistributedExt: Sized {
196190
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
197191
/// # use datafusion::prelude::SessionConfig;
198192
/// # use url::Url;
199-
/// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedSessionBuilderContext};
193+
/// # use std::sync::Arc;
194+
/// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext};
200195
///
201196
/// struct CustomChannelResolver;
202197
///
@@ -211,25 +206,29 @@ pub trait DistributedExt: Sized {
211206
/// }
212207
/// }
213208
///
209+
/// // This tweaks the SessionState so that it can plan for distributed queries and execute them.
214210
/// let state = SessionStateBuilder::new()
215-
/// .with_distributed_execution(CustomChannelResolver)
211+
/// .with_distributed_channel_resolver(CustomChannelResolver)
212+
/// // the DistributedPhysicalOptimizerRule also needs to be passed so that query plans
213+
/// // get distributed.
214+
/// .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
216215
/// .build();
217216
///
217+
/// // This function can be provided to an ArrowFlightEndpoint so that, upon receiving a distributed
218+
/// // part of a plan, it knows how to resolve gRPC channels from URLs for making network calls to other nodes.
218219
/// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
219-
/// // This function can be provided to an ArrowFlightEndpoint so that it knows how to
220-
/// // resolve tonic channels from URLs upon making network calls to other nodes.
221220
/// Ok(SessionStateBuilder::new()
222-
/// .with_distributed_execution(CustomChannelResolver)
221+
/// .with_distributed_channel_resolver(CustomChannelResolver)
223222
/// .build())
224223
/// }
225224
/// ```
226-
fn with_distributed_execution<T: ChannelResolver + Send + Sync + 'static>(
225+
fn with_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(
227226
self,
228227
resolver: T,
229228
) -> Self;
230229

231-
/// Same as [DistributedExt::with_distributed_execution] but with an in-place mutation.
232-
fn set_distributed_execution<T: ChannelResolver + Send + Sync + 'static>(
230+
/// Same as [DistributedExt::with_distributed_channel_resolver] but with an in-place mutation.
231+
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(
233232
&mut self,
234233
resolver: T,
235234
);
@@ -317,14 +316,12 @@ impl DistributedExt for SessionStateBuilder {
317316
set_distributed_user_codec_arc(self.config().get_or_insert_default(), codec)
318317
}
319318

320-
fn set_distributed_execution<T: ChannelResolver + Send + Sync + 'static>(
319+
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(
321320
&mut self,
322321
resolver: T,
323322
) {
324323
let cfg = self.config().get_or_insert_default();
325324
set_distributed_channel_resolver(cfg, resolver);
326-
let rules = self.physical_optimizer_rules().get_or_insert_default();
327-
rules.push(Arc::new(DistributedPhysicalOptimizerRule));
328325
}
329326

330327
fn set_distributed_task_estimator<T: TaskEstimator + Send + Sync + 'static>(
@@ -372,9 +369,9 @@ impl DistributedExt for SessionStateBuilder {
372369
#[expr($;self)]
373370
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
374371

375-
#[call(set_distributed_execution)]
372+
#[call(set_distributed_channel_resolver)]
376373
#[expr($;self)]
377-
fn with_distributed_execution<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
374+
fn with_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
378375

379376
#[call(set_distributed_task_estimator)]
380377
#[expr($;self)]

src/distributed_planner/distributed_physical_optimizer_rule.rs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -403,12 +403,13 @@ impl<T: ?Sized> Referenced<'_, T> {
403403

404404
#[cfg(test)]
405405
mod tests {
406-
use crate::DistributedExt;
407406
use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
408407
use crate::test_utils::parquet::register_parquet_tables;
408+
use crate::{DistributedExt, DistributedPhysicalOptimizerRule};
409409
use crate::{assert_snapshot, display_plan_ascii};
410410
use datafusion::execution::SessionStateBuilder;
411411
use datafusion::prelude::{SessionConfig, SessionContext};
412+
use std::sync::Arc;
412413
/* schema for the "weather" table
413414
414415
MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
@@ -441,7 +442,7 @@ mod tests {
441442
SELECT * FROM weather
442443
"#;
443444
let plan = sql_to_explain(query, |b| {
444-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
445+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
445446
})
446447
.await;
447448
assert_snapshot!(plan, @"DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet");
@@ -453,7 +454,7 @@ mod tests {
453454
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
454455
"#;
455456
let plan = sql_to_explain(query, |b| {
456-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
457+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
457458
})
458459
.await;
459460
assert_snapshot!(plan, @r"
@@ -485,7 +486,7 @@ mod tests {
485486
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
486487
"#;
487488
let plan = sql_to_explain(query, |b| {
488-
b.with_distributed_execution(InMemoryChannelResolver::new(2))
489+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(2))
489490
})
490491
.await;
491492
assert_snapshot!(plan, @r"
@@ -517,7 +518,7 @@ mod tests {
517518
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
518519
"#;
519520
let plan = sql_to_explain(query, |b| {
520-
b.with_distributed_execution(InMemoryChannelResolver::new(0))
521+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(0))
521522
})
522523
.await;
523524
assert_snapshot!(plan, @r"
@@ -540,7 +541,7 @@ mod tests {
540541
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
541542
"#;
542543
let plan = sql_to_explain(query, |b| {
543-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
544+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
544545
.with_distributed_cardinality_effect_task_scale_factor(3.0)
545546
.unwrap()
546547
})
@@ -571,7 +572,7 @@ mod tests {
571572
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
572573
"#;
573574
let plan = sql_to_explain(query, |b| {
574-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
575+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
575576
.with_distributed_files_per_task(3)
576577
.unwrap()
577578
})
@@ -596,7 +597,7 @@ mod tests {
596597
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
597598
"#;
598599
let plan = sql_to_explain(query, |b| {
599-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
600+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
600601
})
601602
.await;
602603
assert_snapshot!(plan, @r"
@@ -628,7 +629,7 @@ mod tests {
628629
SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday"
629630
"#;
630631
let plan = sql_to_explain(query, |b| {
631-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
632+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
632633
})
633634
.await;
634635
assert_snapshot!(plan, @r"
@@ -666,7 +667,7 @@ mod tests {
666667
ON a."RainTomorrow" = b."RainTomorrow"
667668
"#;
668669
let plan = sql_to_explain(query, |b| {
669-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
670+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
670671
})
671672
.await;
672673
assert_snapshot!(plan, @r"
@@ -714,7 +715,7 @@ mod tests {
714715
SELECT * FROM weather ORDER BY "MinTemp" DESC
715716
"#;
716717
let plan = sql_to_explain(query, |b| {
717-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
718+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
718719
})
719720
.await;
720721
assert_snapshot!(plan, @r"
@@ -736,7 +737,7 @@ mod tests {
736737
SELECT DISTINCT "RainToday", "WindGustDir" FROM weather
737738
"#;
738739
let plan = sql_to_explain(query, |b| {
739-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
740+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
740741
})
741742
.await;
742743
assert_snapshot!(plan, @r"
@@ -765,7 +766,7 @@ mod tests {
765766
SHOW COLUMNS from weather
766767
"#;
767768
let plan = sql_to_explain(query, |b| {
768-
b.with_distributed_execution(InMemoryChannelResolver::new(3))
769+
b.with_distributed_channel_resolver(InMemoryChannelResolver::new(3))
769770
})
770771
.await;
771772
assert_snapshot!(plan, @r"
@@ -788,6 +789,7 @@ mod tests {
788789

789790
let builder = SessionStateBuilder::new()
790791
.with_default_features()
792+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
791793
.with_config(config);
792794

793795
let state = f(builder).build();

src/metrics/task_metrics_collector.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,12 @@ mod tests {
127127
use datafusion::arrow::record_batch::RecordBatch;
128128
use futures::StreamExt;
129129

130-
use crate::DistributedExt;
131130
use crate::execution_plans::DistributedExec;
132131
use crate::metrics::proto::metrics_set_proto_to_df;
133132
use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
134133
use crate::test_utils::plans::{count_plan_nodes, get_stages_and_stage_keys};
135134
use crate::test_utils::session_context::register_temp_parquet_table;
135+
use crate::{DistributedExt, DistributedPhysicalOptimizerRule};
136136
use datafusion::execution::{SessionStateBuilder, context::SessionContext};
137137
use datafusion::prelude::SessionConfig;
138138
use datafusion::{
@@ -151,7 +151,8 @@ mod tests {
151151
let state = SessionStateBuilder::new()
152152
.with_default_features()
153153
.with_config(config)
154-
.with_distributed_execution(InMemoryChannelResolver::new(10))
154+
.with_distributed_channel_resolver(InMemoryChannelResolver::new(10))
155+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
155156
.with_distributed_task_estimator(2)
156157
.build();
157158

src/metrics/task_metrics_rewriter.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ pub fn stage_metrics_rewriter(
193193

194194
#[cfg(test)]
195195
mod tests {
196-
use crate::DistributedExec;
197196
use crate::PartitionIsolatorExec;
198197
use crate::metrics::proto::{
199198
MetricsSetProto, df_metrics_set_to_proto, metrics_set_proto_to_df,
@@ -205,6 +204,7 @@ mod tests {
205204
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
206205
use crate::test_utils::plans::count_plan_nodes;
207206
use crate::test_utils::session_context::register_temp_parquet_table;
207+
use crate::{DistributedExec, DistributedPhysicalOptimizerRule};
208208
use crate::{NetworkBoundaryExt, Stage};
209209
use bytes::Bytes;
210210
use datafusion::arrow::array::{Int32Array, StringArray};
@@ -244,7 +244,8 @@ mod tests {
244244

245245
if distributed {
246246
builder = builder
247-
.with_distributed_execution(InMemoryChannelResolver::new(10))
247+
.with_distributed_channel_resolver(InMemoryChannelResolver::new(10))
248+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
248249
.with_distributed_task_estimator(2)
249250
}
250251

src/test_utils/in_memory_channel_resolver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ impl InMemoryChannelResolver {
4545
async move {
4646
let builder = SessionStateBuilder::new()
4747
.with_default_features()
48-
.with_distributed_execution(this)
48+
.with_distributed_channel_resolver(this)
4949
.with_runtime_env(ctx.runtime_env.clone());
5050
Ok(builder.build())
5151
}

src/test_utils/localhost.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
3-
DistributedSessionBuilder, DistributedSessionBuilderContext,
3+
DistributedPhysicalOptimizerRule, DistributedSessionBuilder, DistributedSessionBuilderContext,
44
MappedDistributedSessionBuilderExt, create_flight_client,
55
};
66
use arrow_flight::flight_service_client::FlightServiceClient;
@@ -53,7 +53,10 @@ where
5353
let channel_resolver = LocalHostChannelResolver::new(ports.clone());
5454
let session_builder = session_builder.map(move |builder: SessionStateBuilder| {
5555
let channel_resolver = channel_resolver.clone();
56-
Ok(builder.with_distributed_execution(channel_resolver).build())
56+
Ok(builder
57+
.with_distributed_channel_resolver(channel_resolver)
58+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
59+
.build())
5760
});
5861
let mut join_set = JoinSet::new();
5962
for listener in listeners {

0 commit comments

Comments
 (0)