Skip to content

Commit 2832861

Browse files
committed
Fix planning when the number of worker URLs is zero or one, and add tests
1 parent e932754 commit 2832861

File tree

4 files changed

+166
-28
lines changed

4 files changed

+166
-28
lines changed

src/distributed_planner/distributed_physical_optimizer_rule.rs

Lines changed: 160 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ pub fn apply_network_boundaries(
121121
}
122122
let distributed_cfg = DistributedConfig::from_config_options(cfg)?;
123123
let urls = distributed_cfg.__private_channel_resolver.0.get_urls()?;
124+
// If there are zero or one available workers, it does not make sense to distribute the query,
125+
// so don't.
126+
if urls.len() <= 1 {
127+
return Ok(plan);
128+
}
124129
let ctx = _apply_network_boundaries(plan, cfg, urls.len())?;
125130
Ok(ctx.plan)
126131
}
@@ -402,7 +407,6 @@ mod tests {
402407
use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
403408
use crate::test_utils::parquet::register_parquet_tables;
404409
use crate::{assert_snapshot, display_plan_ascii};
405-
use datafusion::error::DataFusionError;
406410
use datafusion::execution::SessionStateBuilder;
407411
use datafusion::prelude::{SessionConfig, SessionContext};
408412
/* schema for the "weather" table
@@ -436,7 +440,10 @@ mod tests {
436440
let query = r#"
437441
SELECT * FROM weather
438442
"#;
439-
let plan = sql_to_explain(query).await.unwrap();
443+
let plan = sql_to_explain(query, |b| {
444+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
445+
})
446+
.await;
440447
assert_snapshot!(plan, @"DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet");
441448
}
442449

@@ -445,7 +452,10 @@ mod tests {
445452
let query = r#"
446453
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
447454
"#;
448-
let plan = sql_to_explain(query).await.unwrap();
455+
let plan = sql_to_explain(query, |b| {
456+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
457+
})
458+
.await;
449459
assert_snapshot!(plan, @r"
450460
┌───── DistributedExec ── Tasks: t0:[p0]
451461
│ ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
@@ -469,12 +479,126 @@ mod tests {
469479
");
470480
}
471481

482+
#[tokio::test]
483+
async fn test_aggregation_with_fewer_workers_than_files() {
484+
let query = r#"
485+
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
486+
"#;
487+
let plan = sql_to_explain(query, |b| {
488+
b.with_distributed_execution(InMemoryChannelResolver::new(2))
489+
})
490+
.await;
491+
assert_snapshot!(plan, @r"
492+
┌───── DistributedExec ── Tasks: t0:[p0]
493+
│ ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
494+
│ SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
495+
│ [Stage 2] => NetworkCoalesceExec: output_partitions=8, input_tasks=2
496+
└──────────────────────────────────────────────────
497+
┌───── Stage 2 ── Tasks: t0:[p0..p3] t1:[p0..p3]
498+
│ SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true]
499+
│ ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
500+
│ AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
501+
│ CoalesceBatchesExec: target_batch_size=8192
502+
│ [Stage 1] => NetworkShuffleExec: output_partitions=4, input_tasks=2
503+
└──────────────────────────────────────────────────
504+
┌───── Stage 1 ── Tasks: t0:[p0..p7] t1:[p0..p7]
505+
│ RepartitionExec: partitioning=Hash([RainToday@0], 8), input_partitions=4
506+
│ RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2
507+
│ AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
508+
│ PartitionIsolatorExec: t0:[p0,p1,__] t1:[__,__,p0]
509+
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet
510+
└──────────────────────────────────────────────────
511+
");
512+
}
513+
514+
#[tokio::test]
515+
async fn test_aggregation_with_0_workers() {
516+
let query = r#"
517+
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
518+
"#;
519+
let plan = sql_to_explain(query, |b| {
520+
b.with_distributed_execution(InMemoryChannelResolver::new(0))
521+
})
522+
.await;
523+
assert_snapshot!(plan, @r"
524+
ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
525+
SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
526+
SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true]
527+
ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
528+
AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
529+
CoalesceBatchesExec: target_batch_size=8192
530+
RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4
531+
RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3
532+
AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
533+
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet
534+
");
535+
}
536+
537+
#[tokio::test]
538+
async fn test_aggregation_with_high_cardinality_factor() {
539+
let query = r#"
540+
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
541+
"#;
542+
let plan = sql_to_explain(query, |b| {
543+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
544+
.with_distributed_cardinality_effect_task_scale_factor(3.0)
545+
.unwrap()
546+
})
547+
.await;
548+
assert_snapshot!(plan, @r"
549+
┌───── DistributedExec ── Tasks: t0:[p0]
550+
│ ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
551+
│ SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
552+
│ SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true]
553+
│ ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
554+
│ AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
555+
│ CoalesceBatchesExec: target_batch_size=8192
556+
│ [Stage 1] => NetworkShuffleExec: output_partitions=4, input_tasks=3
557+
└──────────────────────────────────────────────────
558+
┌───── Stage 1 ── Tasks: t0:[p0..p3] t1:[p0..p3] t2:[p0..p3]
559+
│ RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4
560+
│ RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
561+
│ AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
562+
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
563+
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet
564+
└──────────────────────────────────────────────────
565+
");
566+
}
567+
568+
#[tokio::test]
569+
async fn test_aggregation_with_a_lot_of_files_per_task() {
570+
let query = r#"
571+
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
572+
"#;
573+
let plan = sql_to_explain(query, |b| {
574+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
575+
.with_distributed_files_per_task(3)
576+
.unwrap()
577+
})
578+
.await;
579+
assert_snapshot!(plan, @r"
580+
ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
581+
SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
582+
SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true]
583+
ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
584+
AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
585+
CoalesceBatchesExec: target_batch_size=8192
586+
RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=4
587+
RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3
588+
AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
589+
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet
590+
");
591+
}
592+
472593
#[tokio::test]
473594
async fn test_aggregation_with_partitions_per_task() {
474595
let query = r#"
475596
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
476597
"#;
477-
let plan = sql_to_explain(query).await.unwrap();
598+
let plan = sql_to_explain(query, |b| {
599+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
600+
})
601+
.await;
478602
assert_snapshot!(plan, @r"
479603
┌───── DistributedExec ── Tasks: t0:[p0]
480604
│ ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
@@ -503,7 +627,10 @@ mod tests {
503627
let query = r#"
504628
SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday"
505629
"#;
506-
let plan = sql_to_explain(query).await.unwrap();
630+
let plan = sql_to_explain(query, |b| {
631+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
632+
})
633+
.await;
507634
assert_snapshot!(plan, @r"
508635
CoalesceBatchesExec: target_batch_size=8192
509636
HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
@@ -538,7 +665,10 @@ mod tests {
538665
LEFT JOIN b
539666
ON a."RainTomorrow" = b."RainTomorrow"
540667
"#;
541-
let plan = sql_to_explain(query).await.unwrap();
668+
let plan = sql_to_explain(query, |b| {
669+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
670+
})
671+
.await;
542672
assert_snapshot!(plan, @r"
543673
┌───── DistributedExec ── Tasks: t0:[p0]
544674
│ CoalescePartitionsExec
@@ -583,7 +713,10 @@ mod tests {
583713
let query = r#"
584714
SELECT * FROM weather ORDER BY "MinTemp" DESC
585715
"#;
586-
let plan = sql_to_explain(query).await.unwrap();
716+
let plan = sql_to_explain(query, |b| {
717+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
718+
})
719+
.await;
587720
assert_snapshot!(plan, @r"
588721
┌───── DistributedExec ── Tasks: t0:[p0]
589722
│ SortPreservingMergeExec: [MinTemp@0 DESC]
@@ -602,7 +735,10 @@ mod tests {
602735
let query = r#"
603736
SELECT DISTINCT "RainToday", "WindGustDir" FROM weather
604737
"#;
605-
let plan = sql_to_explain(query).await.unwrap();
738+
let plan = sql_to_explain(query, |b| {
739+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
740+
})
741+
.await;
606742
assert_snapshot!(plan, @r"
607743
┌───── DistributedExec ── Tasks: t0:[p0]
608744
│ CoalescePartitionsExec
@@ -628,7 +764,10 @@ mod tests {
628764
let query = r#"
629765
SHOW COLUMNS from weather
630766
"#;
631-
let plan = sql_to_explain(query).await.unwrap();
767+
let plan = sql_to_explain(query, |b| {
768+
b.with_distributed_execution(InMemoryChannelResolver::new(3))
769+
})
770+
.await;
632771
assert_snapshot!(plan, @r"
633772
CoalescePartitionsExec
634773
ProjectionExec: expr=[table_catalog@0 as table_catalog, table_schema@1 as table_schema, table_name@2 as table_name, column_name@3 as column_name, data_type@5 as data_type, is_nullable@4 as is_nullable]
@@ -639,26 +778,29 @@ mod tests {
639778
");
640779
}
641780

642-
async fn sql_to_explain(query: &str) -> Result<String, DataFusionError> {
781+
async fn sql_to_explain(
782+
query: &str,
783+
f: impl FnOnce(SessionStateBuilder) -> SessionStateBuilder,
784+
) -> String {
643785
let config = SessionConfig::new()
644786
.with_target_partitions(4)
645787
.with_information_schema(true);
646788

647-
let state = SessionStateBuilder::new()
789+
let builder = SessionStateBuilder::new()
648790
.with_default_features()
649-
.with_config(config)
650-
.with_distributed_execution(InMemoryChannelResolver::new())
651-
.build();
791+
.with_config(config);
792+
793+
let state = f(builder).build();
652794

653795
let ctx = SessionContext::new_with_state(state);
654-
register_parquet_tables(&ctx).await?;
796+
register_parquet_tables(&ctx).await.unwrap();
655797

656798
let mut df = None;
657799
for query in query.split(";") {
658-
df = Some(ctx.sql(query).await?);
800+
df = Some(ctx.sql(query).await.unwrap());
659801
}
660802

661-
let physical_plan = df.unwrap().create_physical_plan().await?;
662-
Ok(display_plan_ascii(physical_plan.as_ref(), false))
803+
let physical_plan = df.unwrap().create_physical_plan().await.unwrap();
804+
display_plan_ascii(physical_plan.as_ref(), false)
663805
}
664806
}

src/metrics/task_metrics_collector.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ mod tests {
151151
let state = SessionStateBuilder::new()
152152
.with_default_features()
153153
.with_config(config)
154-
.with_distributed_execution(InMemoryChannelResolver::new())
154+
.with_distributed_execution(InMemoryChannelResolver::new(10))
155155
.with_distributed_task_estimator(2)
156156
.build();
157157

src/metrics/task_metrics_rewriter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ mod tests {
244244

245245
if distributed {
246246
builder = builder
247-
.with_distributed_execution(InMemoryChannelResolver::new())
247+
.with_distributed_execution(InMemoryChannelResolver::new(10))
248248
.with_distributed_task_estimator(2)
249249
}
250250

src/test_utils/in_memory_channel_resolver.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,11 @@ const DUMMY_URL: &str = "http://localhost:50051";
1616
#[derive(Clone)]
1717
pub struct InMemoryChannelResolver {
1818
channel: FlightServiceClient<BoxCloneSyncChannel>,
19-
}
20-
21-
impl Default for InMemoryChannelResolver {
22-
fn default() -> Self {
23-
Self::new()
24-
}
19+
n_workers: usize,
2520
}
2621

2722
impl InMemoryChannelResolver {
28-
pub fn new() -> Self {
23+
pub fn new(n_workers: usize) -> Self {
2924
let (client, server) = tokio::io::duplex(1024 * 1024);
3025

3126
let mut client = Some(client);
@@ -40,6 +35,7 @@ impl InMemoryChannelResolver {
4035

4136
let this = Self {
4237
channel: create_flight_client(BoxCloneSyncChannel::new(channel)),
38+
n_workers,
4339
};
4440
let this_clone = this.clone();
4541

@@ -72,7 +68,7 @@ impl ChannelResolver for InMemoryChannelResolver {
7268
fn get_urls(&self) -> Result<Vec<url::Url>, DataFusionError> {
7369
// Return one dummy URL per configured worker so that tests control exactly how many
7470
// workers the distributed planner sees.
75-
Ok(vec![url::Url::parse(DUMMY_URL).unwrap(); 100])
71+
Ok(vec![url::Url::parse(DUMMY_URL).unwrap(); self.n_workers])
7672
}
7773

7874
async fn get_flight_client_for_url(

0 commit comments

Comments
 (0)