1 change: 0 additions & 1 deletion src/common/mod.rs
@@ -1,2 +1 @@
pub mod result;
pub mod util;
1 change: 0 additions & 1 deletion src/common/result.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/flight_service/do_get.rs
@@ -8,7 +8,6 @@ use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionStateBuilder;
use datafusion::optimizer::OptimizerConfig;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use futures::TryStreamExt;
use prost::Message;
278 changes: 277 additions & 1 deletion src/physical_optimizer.rs

Large diffs are not rendered by default.

60 changes: 29 additions & 31 deletions src/plan/arrow_flight_read.rs
@@ -1,11 +1,16 @@
use super::combined::CombinedRecordBatchStream;
use crate::channel_manager::ChannelManager;
use crate::errors::tonic_status_to_datafusion_error;
use crate::flight_service::DoGet;
use crate::stage::{ExecutionStage, ExecutionStageProto};
use arrow_flight::{FlightClient, Ticket};
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::Ticket;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{internal_datafusion_err, plan_err};
use datafusion::error::Result;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -15,11 +20,8 @@ use prost::Message;
use std::any::Any;
use std::fmt::Formatter;
use std::sync::Arc;
use tonic::transport::Channel;
use url::Url;

use super::combined::CombinedRecordBatchStream;

#[derive(Debug, Clone)]
pub struct ArrowFlightReadExec {
/// the number of the stage we are reading from
@@ -125,8 +127,6 @@ impl ExecutionPlan for ArrowFlightReadExec {
let schema = child_stage.plan.schema();

let stream = async move {
// concurrently build streams for each stage
// TODO: tokio spawn instead here?
let futs = child_stage_tasks.iter().map(|task| async {
let url = task.url()?.ok_or(internal_datafusion_err!(
"ArrowFlightReadExec: task is unassigned, cannot proceed"
@@ -153,31 +153,29 @@ async fn stream_from_stage_task(
ticket: Ticket,
url: &Url,
schema: SchemaRef,
_channel_manager: &ChannelManager,
) -> Result<SendableRecordBatchStream> {
// FIXME: I cannot figure out how to use the arrow_flight::client::FlightClient (a mid-level
// client) with the ChannelManager, so we will create a new Channel directly for now
channel_manager: &ChannelManager,
) -> Result<SendableRecordBatchStream, DataFusionError> {
let channel = channel_manager.get_channel_for_url(&url).await?;

//let channel = channel_manager.get_channel_for_url(&url).await?;

let channel = Channel::from_shared(url.to_string())
.map_err(|e| internal_datafusion_err!("Failed to create channel from URL: {e:#?}"))?
.connect()
.await
.map_err(|e| internal_datafusion_err!("Failed to connect to channel: {e:#?}"))?;

let mut client = FlightClient::new(channel);

let flight_stream = client
let mut client = FlightServiceClient::new(channel);
let stream = client
.do_get(ticket)
.await
.map_err(|e| internal_datafusion_err!("Failed to execute do_get for ticket: {e:#?}"))?;

let record_batch_stream = RecordBatchStreamAdapter::new(
.map_err(|err| {
tonic_status_to_datafusion_error(&err)
.unwrap_or_else(|| DataFusionError::External(Box::new(err)))
})?
.into_inner()
.map_err(|err| FlightError::Tonic(Box::new(err)));

let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err {
FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status)
.unwrap_or_else(|| DataFusionError::External(Box::new(status))),
err => DataFusionError::External(Box::new(err)),
});

Ok(Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
flight_stream
.map_err(|e| internal_datafusion_err!("Failed to decode flight stream: {e:#?}")),
);

Ok(Box::pin(record_batch_stream) as SendableRecordBatchStream)
stream,
)))
}
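
Because the rendered hunk interleaves removed and added lines, the new body of stream_from_stage_task is hard to read in place. Below is a reassembled sketch of the added code only, pieced together from the hunk above and not checked against the full file; it assumes the crate's ChannelManager exposes get_channel_for_url and that tonic_status_to_datafusion_error lives in crate::errors, as the imports at the top of the diff suggest.

use crate::channel_manager::ChannelManager;
use crate::errors::tonic_status_to_datafusion_error;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::Ticket;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::TryStreamExt;
use url::Url;

async fn stream_from_stage_task(
    ticket: Ticket,
    url: &Url,
    schema: SchemaRef,
    channel_manager: &ChannelManager,
) -> Result<SendableRecordBatchStream, DataFusionError> {
    // Reuse a managed channel for this worker URL instead of dialing a fresh
    // tonic Channel per task (the old FIXME path).
    let channel = channel_manager.get_channel_for_url(url).await?;
    let mut client = FlightServiceClient::new(channel);

    // Translate a tonic Status with the crate's helper, falling back to a
    // generic external error.
    let stream = client
        .do_get(ticket)
        .await
        .map_err(|err| {
            tonic_status_to_datafusion_error(&err)
                .unwrap_or_else(|| DataFusionError::External(Box::new(err)))
        })?
        .into_inner()
        .map_err(|err| FlightError::Tonic(Box::new(err)));

    // Decode FlightData into RecordBatches, mapping transport and decode
    // errors back into DataFusionError.
    let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err {
        FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status)
            .unwrap_or_else(|| DataFusionError::External(Box::new(status))),
        err => DataFusionError::External(Box::new(err)),
    });

    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}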
2 changes: 1 addition & 1 deletion src/stage/display.rs
@@ -14,7 +14,7 @@ use std::fmt::Write;

use datafusion::{
error::Result,
physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan},
physical_plan::{DisplayAs, DisplayFormatType},
};

use crate::{
18 changes: 5 additions & 13 deletions src/stage/proto.rs
@@ -10,7 +10,6 @@ use datafusion_proto::{
physical_plan::{AsExecutionPlan, PhysicalExtensionCodec},
protobuf::PhysicalPlanNode,
};
use prost::Message;

use crate::{plan::DistributedCodec, task::ExecutionTask};

@@ -100,9 +99,7 @@ pub fn stage_from_proto(
}

// add tests for round trip to and from a proto message for ExecutionStage
/* TODO: broken for now
#[cfg(test)]

mod tests {
use std::sync::Arc;

@@ -111,19 +108,13 @@ mod tests {
array::{RecordBatch, StringArray, UInt8Array},
datatypes::{DataType, Field, Schema},
},
catalog::memory::DataSourceExec,
common::{internal_datafusion_err, internal_err},
common::internal_datafusion_err,
datasource::MemTable,
error::{DataFusionError, Result},
error::Result,
execution::context::SessionContext,
prelude::SessionConfig,
};
use datafusion_proto::{
physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec},
protobuf::PhysicalPlanNode,
};
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use prost::Message;
use uuid::Uuid;

use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto};

@@ -148,6 +139,7 @@ }
}

#[tokio::test]
#[ignore]
async fn test_execution_stage_proto_round_trip() -> Result<()> {
let ctx = SessionContext::new();
let mem_table = create_mem_table();
@@ -196,4 +188,4 @@ mod tests {
assert_eq!(stage.name, round_trip_stage.name);
Ok(())
}
}*/
}
3 changes: 1 addition & 2 deletions src/task.rs
@@ -3,9 +3,8 @@ use std::fmt::Display;
use std::fmt::Formatter;

use datafusion::common::internal_datafusion_err;
use prost::Message;

use datafusion::error::Result;

use url::Url;

#[derive(Clone, PartialEq, ::prost::Message)]
24 changes: 24 additions & 0 deletions src/test_utils/insta.rs
@@ -0,0 +1,24 @@
use std::env;

#[macro_export]
macro_rules! assert_snapshot {
($($arg:tt)*) => {
crate::test_utils::insta::settings().bind(|| {
insta::assert_snapshot!($($arg)*);
})
};
}

pub fn settings() -> insta::Settings {
env::set_var("INSTA_WORKSPACE_ROOT", env!("CARGO_MANIFEST_DIR"));
let mut settings = insta::Settings::clone_current();
let cwd = env::current_dir().unwrap();
let cwd = cwd.to_str().unwrap();
settings.add_filter(cwd.trim_start_matches("/"), "");
settings.add_filter(
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
"UUID",
);

settings
}
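
A hypothetical use of this wrapper (the test body and expected snapshot are illustrative, not from the repository): binding the settings applies the workspace-path and UUID filters, so output containing absolute paths or fresh UUIDs snapshots deterministically.

#[cfg(test)]
mod example {
    #[test]
    fn snapshot_is_machine_independent() {
        // The UUID filter rewrites any hex UUID to the literal "UUID", so this
        // inline snapshot passes on every run despite the random id.
        let line = format!("stage id: {}", uuid::Uuid::new_v4());
        crate::assert_snapshot!(line, @"stage id: UUID");
    }
}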
5 changes: 3 additions & 2 deletions src/test_utils/mod.rs
@@ -1,5 +1,6 @@
#[cfg(test)]
pub mod insta;
mod mock_exec;
mod parquet;

#[cfg(test)]
pub use mock_exec::MockExec;
pub use parquet::register_parquet_tables;
20 changes: 20 additions & 0 deletions src/test_utils/parquet.rs
@@ -0,0 +1,20 @@
use datafusion::error::DataFusionError;
use datafusion::prelude::{ParquetReadOptions, SessionContext};

pub async fn register_parquet_tables(ctx: &SessionContext) -> Result<(), DataFusionError> {
ctx.register_parquet(
"flights_1m",
"testdata/flights-1m.parquet",
ParquetReadOptions::default(),
)
.await?;

ctx.register_parquet(
"weather",
"testdata/weather.parquet",
ParquetReadOptions::default(),
)
.await?;

Ok(())
}
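
A hypothetical caller, assuming the two parquet files exist under testdata/ relative to the test's working directory:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::test]
async fn smoke_query_registered_tables() -> Result<()> {
    let ctx = SessionContext::new();
    register_parquet_tables(&ctx).await?;
    // Any query against the registered names works; this one is illustrative.
    ctx.sql("SELECT COUNT(*) FROM flights_1m").await?.show().await?;
    Ok(())
}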
2 changes: 2 additions & 0 deletions tests/common/mod.rs
@@ -1,3 +1,5 @@
pub mod insta;
pub mod localhost;
pub mod parquet;
pub mod plan;
pub mod tpch;
68 changes: 68 additions & 0 deletions tests/common/plan.rs
@@ -0,0 +1,68 @@
use datafusion::common::plan_err;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::error::DataFusionError;
use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_distributed::ArrowFlightReadExec;
use std::sync::Arc;

pub fn distribute_aggregate(
plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let mut aggregate_partial_found = false;
Ok(plan
.transform_up(|node| {
let Some(agg) = node.as_any().downcast_ref::<AggregateExec>() else {
return Ok(Transformed::no(node));
};

match agg.mode() {
AggregateMode::Partial => {
if aggregate_partial_found {
return plan_err!("Two consecutive partial aggregations found");
}
aggregate_partial_found = true;
let expr = agg
.group_expr()
.expr()
.iter()
.map(|(v, _)| Arc::clone(v))
.collect::<Vec<_>>();

if node.children().len() != 1 {
return plan_err!("Aggregate must have exactly one child");
}
let child = node.children()[0].clone();

let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
Partitioning::Hash(expr, 1),
child.schema(),
0,
))])?;
Ok(Transformed::yes(node))
}
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single
| AggregateMode::SinglePartitioned => {
if !aggregate_partial_found {
return plan_err!("No partial aggregate found before the final one");
}

if node.children().len() != 1 {
return plan_err!("Aggregate must have exactly one child");
}
let child = node.children()[0].clone();

let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
Partitioning::RoundRobinBatch(8),
child.schema(),
1,
))])?;
Ok(Transformed::yes(node))
}
}
})?
.data)
}
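
A hypothetical end-to-end use of this helper, assuming a query that plans as the usual Partial/Final aggregate pair; the column name is a placeholder:

use datafusion::error::Result;
use datafusion::physical_plan::displayable;
use datafusion::prelude::SessionContext;

async fn distributed_aggregate_example(ctx: &SessionContext) -> Result<()> {
    let df = ctx
        .sql("SELECT some_col, COUNT(*) FROM flights_1m GROUP BY some_col")
        .await?;
    let plan = df.create_physical_plan().await?;
    // Per the function above: the partial aggregate now reads stage 0 through
    // an ArrowFlightReadExec hash-partitioned on the group keys, and the final
    // aggregate reads stage 1 through a round-robin ArrowFlightReadExec.
    let distributed = distribute_aggregate(plan)?;
    println!("{}", displayable(distributed.as_ref()).indent(true));
    Ok(())
}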
File renamed without changes.
14 changes: 8 additions & 6 deletions tests/custom_extension_codec.rs
@@ -1,6 +1,6 @@
#[allow(dead_code)]
mod common;
/*

#[cfg(test)]
mod tests {
use crate::assert_snapshot;
@@ -27,7 +27,7 @@ mod tests {
use datafusion::physical_plan::{
displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use datafusion_distributed::{assign_stages, ArrowFlightReadExec, SessionBuilder};
use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_proto::protobuf::proto_error;
use futures::{stream, TryStreamExt};
@@ -37,6 +37,7 @@ mod tests {
use std::sync::Arc;

#[tokio::test]
#[ignore]
async fn custom_extension_codec() -> Result<(), Box<dyn std::error::Error>> {
#[derive(Clone)]
struct CustomSessionBuilder;
@@ -66,7 +67,6 @@ ");
");

let distributed_plan = build_plan(true)?;
let distributed_plan = assign_stages(distributed_plan, &ctx)?;

assert_snapshot!(displayable(distributed_plan.as_ref()).indent(true).to_string(), @r"
SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
@@ -124,8 +124,9 @@

if distributed {
plan = Arc::new(ArrowFlightReadExec::new(
plan.clone(),
Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1),
plan.clone().schema(),
0, // TODO: stage num should be assigned by someone else
));
}

@@ -139,8 +140,9 @@

if distributed {
plan = Arc::new(ArrowFlightReadExec::new(
plan.clone(),
Partitioning::RoundRobinBatch(10),
plan.clone().schema(),
1, // TODO: stage num should be assigned by someone else
));

plan = Arc::new(RepartitionExec::try_new(
@@ -266,4 +268,4 @@ mod tests {
.map_err(|err| proto_error(format!("{err}")))
}
}
}*/
}
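
The recurring change in this test is the ArrowFlightReadExec constructor: it no longer wraps its child plan directly, but takes a partitioning, the schema it exposes, and the number of the stage it reads from. A minimal sketch of the new call shape, with names assumed from the diff:

use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_distributed::ArrowFlightReadExec;
use std::sync::Arc;

// `child` is the subtree that will run in the remote stage; only its schema is
// needed here, since the plan itself travels separately inside the stage.
fn wrap_in_flight_read(child: &Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    Arc::new(ArrowFlightReadExec::new(
        Partitioning::RoundRobinBatch(10),
        child.schema(),
        1, // stage number to read from; the TODOs above note this should be assigned elsewhere
    ))
}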