
Commit 72dc970

Uncomment tests in favor of just #[ignore]-ing them

1 parent ccf36a1 · commit 72dc970
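
For context, this commit replaces commented-out test modules with Rust's `#[ignore]` attribute: the tests keep compiling and stay visible to the test runner, but are skipped by default. A minimal sketch of the pattern (the test name and body here are placeholders, not from this repository):

    // Hypothetical standalone example of the #[ignore] pattern applied below.
    #[test]
    #[ignore] // skipped by a plain `cargo test`; run with `cargo test -- --ignored`
    fn some_distributed_test() {
        // test body elided
    }

Unlike wrapping a module in /* ... */, ignored tests still compile against the current APIs and so keep breaking the build when those APIs drift, which is the point of this change.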

File tree

7 files changed: +134 −25 lines

tests/common/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 pub mod insta;
 pub mod localhost;
 pub mod parquet;
+pub mod plan;

tests/common/plan.rs

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+use datafusion::common::plan_err;
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::error::DataFusionError;
+use datafusion::physical_expr::Partitioning;
+use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_distributed::ArrowFlightReadExec;
+use std::sync::Arc;
+
+pub fn distribute_aggregate(
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+    let mut aggregate_partial_found = false;
+    Ok(plan
+        .transform_up(|node| {
+            let Some(agg) = node.as_any().downcast_ref::<AggregateExec>() else {
+                return Ok(Transformed::no(node));
+            };
+
+            match agg.mode() {
+                AggregateMode::Partial => {
+                    if aggregate_partial_found {
+                        return plan_err!("Two consecutive partial aggregations found");
+                    }
+                    aggregate_partial_found = true;
+                    let expr = agg
+                        .group_expr()
+                        .expr()
+                        .iter()
+                        .map(|(v, _)| Arc::clone(v))
+                        .collect::<Vec<_>>();
+
+                    if node.children().len() != 1 {
+                        return plan_err!("Aggregate must have exactly one child");
+                    }
+                    let child = node.children()[0].clone();
+
+                    let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
+                        Partitioning::Hash(expr, 1),
+                        child.schema(),
+                        0,
+                    ))])?;
+                    Ok(Transformed::yes(node))
+                }
+                AggregateMode::Final
+                | AggregateMode::FinalPartitioned
+                | AggregateMode::Single
+                | AggregateMode::SinglePartitioned => {
+                    if !aggregate_partial_found {
+                        return plan_err!("No partial aggregate found before the final one");
+                    }
+
+                    if node.children().len() != 1 {
+                        return plan_err!("Aggregate must have exactly one child");
+                    }
+                    let child = node.children()[0].clone();
+
+                    let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
+                        Partitioning::RoundRobinBatch(8),
+                        child.schema(),
+                        1,
+                    ))])?;
+                    Ok(Transformed::yes(node))
+                }
+            }
+        })?
+        .data)
+}
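
This helper is exercised by tests/distributed_aggregation.rs further down; a minimal sketch of the call site, using the same variable names as that test:

    // From the test below: rewrite a single-node physical plan by inserting
    // ArrowFlightReadExec boundaries under the partial and final aggregates.
    let physical_distributed = distribute_aggregate(physical.clone())?;
    let physical_distributed_str = displayable(physical_distributed.as_ref())
        .indent(true)
        .to_string();

The Hash(expr, 1) and RoundRobinBatch(8) partitionings and the stage indices 0 and 1 are hard-coded for the test; the TODOs elsewhere in this commit note that stage numbers should eventually be assigned by someone else.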

tests/custom_extension_codec.rs

Lines changed: 8 additions & 6 deletions
@@ -1,6 +1,6 @@
 #[allow(dead_code)]
 mod common;
-/*
+
 #[cfg(test)]
 mod tests {
     use crate::assert_snapshot;
@@ -27,7 +27,7 @@ mod tests {
     use datafusion::physical_plan::{
         displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
     };
-    use datafusion_distributed::{assign_stages, ArrowFlightReadExec, SessionBuilder};
+    use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
     use datafusion_proto::physical_plan::PhysicalExtensionCodec;
     use datafusion_proto::protobuf::proto_error;
     use futures::{stream, TryStreamExt};
@@ -37,6 +37,7 @@ mod tests {
     use std::sync::Arc;

     #[tokio::test]
+    #[ignore]
     async fn custom_extension_codec() -> Result<(), Box<dyn std::error::Error>> {
         #[derive(Clone)]
         struct CustomSessionBuilder;
@@ -66,7 +67,6 @@ mod tests {
         ");

         let distributed_plan = build_plan(true)?;
-        let distributed_plan = assign_stages(distributed_plan, &ctx)?;

         assert_snapshot!(displayable(distributed_plan.as_ref()).indent(true).to_string(), @r"
         SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
@@ -124,8 +124,9 @@ mod tests {

         if distributed {
             plan = Arc::new(ArrowFlightReadExec::new(
-                plan.clone(),
                 Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1),
+                plan.clone().schema(),
+                0, // TODO: stage num should be assigned by someone else
             ));
         }

@@ -139,8 +140,9 @@ mod tests {

         if distributed {
             plan = Arc::new(ArrowFlightReadExec::new(
-                plan.clone(),
                 Partitioning::RoundRobinBatch(10),
+                plan.clone().schema(),
+                1, // TODO: stage num should be assigned by someone else
             ));

             plan = Arc::new(RepartitionExec::try_new(
@@ -266,4 +268,4 @@ mod tests {
             .map_err(|err| proto_error(format!("{err}")))
     }
 }
-}*/
+}
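
The hunks above also show the new ArrowFlightReadExec::new call shape used throughout this commit: the child plan is no longer passed as the first argument; instead the call takes a partitioning, the input schema, and a stage number. A sketch inferred from this diff (the exact constructor signature is not shown on this page):

    // Before this commit (removed lines): child plan passed directly.
    // plan = Arc::new(ArrowFlightReadExec::new(
    //     plan.clone(),
    //     Partitioning::RoundRobinBatch(10),
    // ));

    // After this commit (added lines): partitioning, input schema, stage number.
    plan = Arc::new(ArrowFlightReadExec::new(
        Partitioning::RoundRobinBatch(10),
        plan.clone().schema(),
        1, // TODO (from the diff): stage num should be assigned by someone else
    ));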

tests/distributed_aggregation.rs

Lines changed: 40 additions & 3 deletions
@@ -6,12 +6,14 @@ mod tests {
     use crate::assert_snapshot;
     use crate::common::localhost::{start_localhost_context, NoopSessionBuilder};
     use crate::common::parquet::register_parquet_tables;
+    use crate::common::plan::distribute_aggregate;
     use datafusion::arrow::util::pretty::pretty_format_batches;
     use datafusion::physical_plan::{displayable, execute_stream};
     use futures::TryStreamExt;
     use std::error::Error;

     #[tokio::test]
+    #[ignore]
     async fn distributed_aggregation() -> Result<(), Box<dyn Error>> {
         // FIXME these ports are in use on my machine, we should find unused ports
         // Changed them for now
@@ -26,9 +28,13 @@ mod tests {

         let physical_str = displayable(physical.as_ref()).indent(true).to_string();

-        println!("\n\nPhysical Plan:\n{}", physical_str);
+        let physical_distributed = distribute_aggregate(physical.clone())?;

-        /*assert_snapshot!(physical_str,
+        let physical_distributed_str = displayable(physical_distributed.as_ref())
+            .indent(true)
+            .to_string();
+
+        assert_snapshot!(physical_str,
         @r"
         ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
           SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
@@ -41,7 +47,24 @@ mod tests {
                           AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
                             DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
         ",
-        );*/
+        );
+
+        assert_snapshot!(physical_distributed_str,
+        @r"
+        ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
+          SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
+            SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
+              ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
+                AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+                  ArrowFlightReadExec: input_tasks=8 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/]
+                    CoalesceBatchesExec: target_batch_size=8192
+                      RepartitionExec: partitioning=Hash([RainToday@0], CPUs), input_partitions=CPUs
+                        RepartitionExec: partitioning=RoundRobinBatch(CPUs), input_partitions=1
+                          AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
+                            ArrowFlightReadExec: input_tasks=1 hash_expr=[RainToday@0] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50052/]
+                              DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
+        ",
+        );

         let batches = pretty_format_batches(
             &execute_stream(physical, ctx.task_ctx())?
@@ -58,6 +81,20 @@ mod tests {
         +----------+-----------+
         ");

+        let batches_distributed = pretty_format_batches(
+            &execute_stream(physical_distributed, ctx.task_ctx())?
+                .try_collect::<Vec<_>>()
+                .await?,
+        )?;
+        assert_snapshot!(batches_distributed, @r"
+        +----------+-----------+
+        | count(*) | RainToday |
+        +----------+-----------+
+        | 66       | Yes       |
+        | 300      | No        |
+        +----------+-----------+
+        ");
+
         Ok(())
     }
 }

tests/error_propagation.rs

Lines changed: 6 additions & 6 deletions
@@ -1,6 +1,6 @@
 #[allow(dead_code)]
 mod common;
-/*
+
 #[cfg(test)]
 mod tests {
     use crate::common::localhost::start_localhost_context;
@@ -26,6 +26,7 @@ mod tests {
     use std::sync::Arc;

     #[tokio::test]
+    #[ignore]
     async fn test_error_propagation() -> Result<(), Box<dyn Error>> {
         #[derive(Clone)]
         struct CustomSessionBuilder;
@@ -48,14 +49,13 @@ mod tests {

         let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));

-        for size in [1, 2, 3] {
+        for (i, size) in [1, 2, 3].iter().enumerate() {
             plan = Arc::new(ArrowFlightReadExec::new(
-                Partitioning::RoundRobinBatch(size),
+                Partitioning::RoundRobinBatch(*size as usize),
                 plan.schema(),
-                0,
+                i,
             ));
         }
-        let plan = assign_stages(plan, &ctx)?;
         let stream = execute_stream(plan, ctx.task_ctx())?;

         let Err(err) = stream.try_collect::<Vec<_>>().await else {
@@ -170,4 +170,4 @@ mod tests {
             .map_err(|err| proto_error(format!("{err}")))
     }
 }
-}*/
+}

tests/highly_distributed_query.rs

Lines changed: 8 additions & 7 deletions
@@ -1,19 +1,20 @@
 #[allow(dead_code)]
 mod common;
-/*
+
 #[cfg(test)]
 mod tests {
     use crate::assert_snapshot;
     use crate::common::localhost::{start_localhost_context, NoopSessionBuilder};
     use crate::common::parquet::register_parquet_tables;
     use datafusion::physical_expr::Partitioning;
     use datafusion::physical_plan::{displayable, execute_stream};
-    use datafusion_distributed::{assign_stages, ArrowFlightReadExec};
+    use datafusion_distributed::ArrowFlightReadExec;
     use futures::TryStreamExt;
     use std::error::Error;
     use std::sync::Arc;

     #[tokio::test]
+    #[ignore]
     async fn highly_distributed_query() -> Result<(), Box<dyn Error>> {
         let (ctx, _guard) = start_localhost_context(
             [
@@ -29,13 +30,13 @@ mod tests {
         let physical_str = displayable(physical.as_ref()).indent(true).to_string();

         let mut physical_distributed = physical.clone();
-        for size in [1, 10, 5] {
+        for (i, size) in [1, 10, 5].iter().enumerate() {
             physical_distributed = Arc::new(ArrowFlightReadExec::new(
-                physical_distributed.clone(),
-                Partitioning::RoundRobinBatch(size),
+                Partitioning::RoundRobinBatch(*size as usize),
+                physical_distributed.schema(),
+                i,
             ));
         }
-        let physical_distributed = assign_stages(physical_distributed, &ctx)?;
         let physical_distributed_str = displayable(physical_distributed.as_ref())
             .indent(true)
             .to_string();
@@ -75,4 +76,4 @@ mod tests {

         Ok(())
     }
-}*/
+}

tests/stage_planning.rs

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,6 @@
 mod common;
 mod tpch;

-// FIXME: commented out until we figure out how to integrate best with tpch
-/*
 #[cfg(test)]
 mod tests {
     use crate::tpch::tpch_query;
@@ -17,7 +15,9 @@ mod tests {
     use std::error::Error;
     use std::sync::Arc;

+    // FIXME: ignored out until we figure out how to integrate best with tpch
     #[tokio::test]
+    #[ignore]
     async fn stage_planning() -> Result<(), Box<dyn Error>> {
         let config = SessionConfig::new().with_target_partitions(3);

@@ -86,4 +86,4 @@ mod tests {

         Ok(())
     }
-}*/
+}
