Merged
1 change: 0 additions & 1 deletion src/common/mod.rs
@@ -1,2 +1 @@
pub mod result;
pub mod util;
1 change: 0 additions & 1 deletion src/common/result.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/flight_service/do_get.rs
@@ -8,7 +8,6 @@ use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionStateBuilder;
use datafusion::optimizer::OptimizerConfig;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use futures::TryStreamExt;
use prost::Message;
1 change: 0 additions & 1 deletion src/physical_optimizer.rs
@@ -6,7 +6,6 @@ use datafusion::{
tree_node::{Transformed, TreeNode, TreeNodeRewriter},
},
config::ConfigOptions,
datasource::physical_plan::FileSource,
error::Result,
physical_optimizer::PhysicalOptimizerRule,
physical_plan::{
2 changes: 1 addition & 1 deletion src/plan/arrow_flight_read.rs
@@ -5,7 +5,7 @@ use arrow_flight::{FlightClient, Ticket};
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{internal_datafusion_err, plan_err};
use datafusion::error::Result;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
2 changes: 1 addition & 1 deletion src/stage/display.rs
@@ -14,7 +14,7 @@ use std::fmt::Write;

use datafusion::{
error::Result,
physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan},
physical_plan::{DisplayAs, DisplayFormatType},
};

use crate::{
18 changes: 5 additions & 13 deletions src/stage/proto.rs
@@ -10,7 +10,6 @@ use datafusion_proto::{
physical_plan::{AsExecutionPlan, PhysicalExtensionCodec},
protobuf::PhysicalPlanNode,
};
use prost::Message;

use crate::{plan::DistributedCodec, task::ExecutionTask};

@@ -100,9 +99,7 @@ pub fn stage_from_proto(
}

// add tests for round trip to and from a proto message for ExecutionStage
/* TODO: broken for now
#[cfg(test)]

mod tests {
use std::sync::Arc;

@@ -111,19 +108,13 @@ mod tests {
array::{RecordBatch, StringArray, UInt8Array},
datatypes::{DataType, Field, Schema},
},
catalog::memory::DataSourceExec,
common::{internal_datafusion_err, internal_err},
common::internal_datafusion_err,
datasource::MemTable,
error::{DataFusionError, Result},
error::Result,
execution::context::SessionContext,
prelude::SessionConfig,
};
use datafusion_proto::{
physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec},
protobuf::PhysicalPlanNode,
};
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use prost::Message;
use uuid::Uuid;

use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto};

@@ -148,6 +139,7 @@
}

#[tokio::test]
#[ignore]
async fn test_execution_stage_proto_round_trip() -> Result<()> {
let ctx = SessionContext::new();
let mem_table = create_mem_table();
@@ -196,4 +188,4 @@ mod tests {
assert_eq!(stage.name, round_trip_stage.name);
Ok(())
}
}*/
}
3 changes: 1 addition & 2 deletions src/task.rs
@@ -3,9 +3,8 @@ use std::fmt::Display;
use std::fmt::Formatter;

use datafusion::common::internal_datafusion_err;
use prost::Message;

use datafusion::error::Result;

use url::Url;

#[derive(Clone, PartialEq, ::prost::Message)]
2 changes: 2 additions & 0 deletions tests/common/mod.rs
@@ -1,3 +1,5 @@
pub mod insta;
pub mod localhost;
pub mod parquet;
pub mod plan;
pub mod tpch;
68 changes: 68 additions & 0 deletions tests/common/plan.rs
@@ -0,0 +1,68 @@
use datafusion::common::plan_err;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::error::DataFusionError;
use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_distributed::ArrowFlightReadExec;
use std::sync::Arc;

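/// Test helper: walks the plan bottom-up and swaps each aggregate's input
/// for an ArrowFlightReadExec network boundary, so the partial aggregate
/// (hash-partitioned on its group-by expressions, reading stage 0) and the
/// final aggregate (round-robin over 8 partitions, reading stage 1) execute
/// as separate distributed stages. Errors on two consecutive partial
/// aggregates or on a final aggregate with no preceding partial.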
pub fn distribute_aggregate(
plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let mut aggregate_partial_found = false;
Ok(plan
.transform_up(|node| {
let Some(agg) = node.as_any().downcast_ref::<AggregateExec>() else {
return Ok(Transformed::no(node));
};

match agg.mode() {
AggregateMode::Partial => {
if aggregate_partial_found {
return plan_err!("Two consecutive partial aggregations found");
}
aggregate_partial_found = true;
let expr = agg
.group_expr()
.expr()
.iter()
.map(|(v, _)| Arc::clone(v))
.collect::<Vec<_>>();

if node.children().len() != 1 {
return plan_err!("Aggregate must have exactly one child");
}
let child = node.children()[0].clone();

let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
Partitioning::Hash(expr, 1),
child.schema(),
0,
))])?;
Ok(Transformed::yes(node))
}
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single
| AggregateMode::SinglePartitioned => {
if !aggregate_partial_found {
return plan_err!("No partial aggregate found before the final one");
}

if node.children().len() != 1 {
return plan_err!("Aggregate must have exactly one child");
}
let child = node.children()[0].clone();

let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
Partitioning::RoundRobinBatch(8),
child.schema(),
1,
))])?;
Ok(Transformed::yes(node))
}
}
})?
.data)
}
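
For reference, a minimal sketch of how this helper is driven, mirroring the call site in tests/distributed_aggregation.rs below (the SQL text is an assumption reconstructed from that test's plan snapshots):

let df = ctx
    .sql(r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#)
    .await?;
let physical = df.create_physical_plan().await?;
// Rewrite the single-node plan into a staged, distributed one.
let physical_distributed = distribute_aggregate(physical.clone())?;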
File renamed without changes.
14 changes: 8 additions & 6 deletions tests/custom_extension_codec.rs
@@ -1,6 +1,6 @@
#[allow(dead_code)]
mod common;
/*

Collaborator (PR author) comment: Instead of having big chunks of commented code, a more Rusty way of handling this is to just #[ignore] the tests.
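
A minimal sketch of that pattern (the test name here is hypothetical): the body stays compiled and type-checked, cargo test skips it by default, and cargo test -- --ignored still runs it on demand.

#[tokio::test]
#[ignore] // not yet passing; skipped by default, run with: cargo test -- --ignored
async fn round_trip_not_yet_supported() -> Result<(), Box<dyn std::error::Error>> {
    // ...test body unchanged; it compiles but does not run by default...
    Ok(())
}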

#[cfg(test)]
mod tests {
use crate::assert_snapshot;
@@ -27,7 +27,7 @@ mod tests {
use datafusion::physical_plan::{
displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use datafusion_distributed::{assign_stages, ArrowFlightReadExec, SessionBuilder};
use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_proto::protobuf::proto_error;
use futures::{stream, TryStreamExt};
@@ -37,6 +37,7 @@ mod tests {
use std::sync::Arc;

#[tokio::test]
#[ignore]
async fn custom_extension_codec() -> Result<(), Box<dyn std::error::Error>> {
#[derive(Clone)]
struct CustomSessionBuilder;
@@ -66,7 +67,6 @@
");

let distributed_plan = build_plan(true)?;
let distributed_plan = assign_stages(distributed_plan, &ctx)?;

assert_snapshot!(displayable(distributed_plan.as_ref()).indent(true).to_string(), @r"
SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
@@ -124,8 +124,9 @@

if distributed {
plan = Arc::new(ArrowFlightReadExec::new(
plan.clone(),
Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1),
plan.clone().schema(),
0, // TODO: stage num should be assigned by someone else
));
}

@@ -139,8 +140,9 @@

if distributed {
plan = Arc::new(ArrowFlightReadExec::new(
plan.clone(),
Partitioning::RoundRobinBatch(10),
plan.clone().schema(),
1, // TODO: stage num should be assigned by someone else
));

plan = Arc::new(RepartitionExec::try_new(
@@ -266,4 +268,4 @@ mod tests {
.map_err(|err| proto_error(format!("{err}")))
}
}
}*/
}
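
The call sites above imply the new ArrowFlightReadExec constructor shape; a sketch under that assumption (the authoritative signature lives in datafusion_distributed, and the variable names are illustrative):

// arguments: partitioning for the read, schema of the stage being read, stage number
let read = ArrowFlightReadExec::new(
    Partitioning::RoundRobinBatch(10),
    plan.schema(),
    1,
);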
43 changes: 40 additions & 3 deletions tests/distributed_aggregation.rs
@@ -6,12 +6,14 @@ mod tests {
use crate::assert_snapshot;
use crate::common::localhost::{start_localhost_context, NoopSessionBuilder};
use crate::common::parquet::register_parquet_tables;
use crate::common::plan::distribute_aggregate;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::physical_plan::{displayable, execute_stream};
use futures::TryStreamExt;
use std::error::Error;

#[tokio::test]
#[ignore]
async fn distributed_aggregation() -> Result<(), Box<dyn Error>> {
// FIXME these ports are in use on my machine, we should find unused ports
// Changed them for now
@@ -26,9 +28,13 @@

let physical_str = displayable(physical.as_ref()).indent(true).to_string();

println!("\n\nPhysical Plan:\n{}", physical_str);
let physical_distributed = distribute_aggregate(physical.clone())?;

/*assert_snapshot!(physical_str,
let physical_distributed_str = displayable(physical_distributed.as_ref())
.indent(true)
.to_string();

assert_snapshot!(physical_str,
@r"
ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
@@ -41,7 +47,24 @@
AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
",
);*/
);

assert_snapshot!(physical_distributed_str,
@r"
ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]
SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]
SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]
ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))]
AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
ArrowFlightReadExec: input_tasks=8 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/]
CoalesceBatchesExec: target_batch_size=8192
RepartitionExec: partitioning=Hash([RainToday@0], CPUs), input_partitions=CPUs
RepartitionExec: partitioning=RoundRobinBatch(CPUs), input_partitions=1
AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
ArrowFlightReadExec: input_tasks=1 hash_expr=[RainToday@0] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50052/]
DataSourceExec: file_groups={1 group: [[/testdata/weather.parquet]]}, projection=[RainToday], file_type=parquet
",
);

let batches = pretty_format_batches(
&execute_stream(physical, ctx.task_ctx())?
@@ -58,6 +81,20 @@
+----------+-----------+
");

let batches_distributed = pretty_format_batches(
&execute_stream(physical_distributed, ctx.task_ctx())?
.try_collect::<Vec<_>>()
.await?,
)?;
assert_snapshot!(batches_distributed, @r"
+----------+-----------+
| count(*) | RainToday |
+----------+-----------+
| 66 | Yes |
| 300 | No |
+----------+-----------+
");

Ok(())
}
}
12 changes: 6 additions & 6 deletions tests/error_propagation.rs
@@ -1,6 +1,6 @@
#[allow(dead_code)]
mod common;
/*

#[cfg(test)]
mod tests {
use crate::common::localhost::start_localhost_context;
@@ -26,6 +26,7 @@ mod tests {
use std::sync::Arc;

#[tokio::test]
#[ignore]
async fn test_error_propagation() -> Result<(), Box<dyn Error>> {
#[derive(Clone)]
struct CustomSessionBuilder;
@@ -48,14 +49,13 @@

let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));

for size in [1, 2, 3] {
for (i, size) in [1, 2, 3].iter().enumerate() {
plan = Arc::new(ArrowFlightReadExec::new(
Partitioning::RoundRobinBatch(size),
Partitioning::RoundRobinBatch(*size as usize),
plan.schema(),
0,
i,
));
}
let plan = assign_stages(plan, &ctx)?;
let stream = execute_stream(plan, ctx.task_ctx())?;

let Err(err) = stream.try_collect::<Vec<_>>().await else {
@@ -170,4 +170,4 @@ mod tests {
.map_err(|err| proto_error(format!("{err}")))
}
}
}*/
}