1 change: 0 additions & 1 deletion src/common/mod.rs
@@ -1,2 +1 @@
pub mod result;
pub mod util;
1 change: 0 additions & 1 deletion src/common/result.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/flight_service/do_get.rs
@@ -8,7 +8,6 @@ use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionStateBuilder;
use datafusion::optimizer::OptimizerConfig;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use futures::TryStreamExt;
use prost::Message;
1 change: 0 additions & 1 deletion src/physical_optimizer.rs
@@ -6,7 +6,6 @@ use datafusion::{
tree_node::{Transformed, TreeNode, TreeNodeRewriter},
},
config::ConfigOptions,
datasource::physical_plan::FileSource,
error::Result,
physical_optimizer::PhysicalOptimizerRule,
physical_plan::{
60 changes: 29 additions & 31 deletions src/plan/arrow_flight_read.rs
@@ -1,11 +1,16 @@
use super::combined::CombinedRecordBatchStream;
use crate::channel_manager::ChannelManager;
use crate::errors::tonic_status_to_datafusion_error;
use crate::flight_service::DoGet;
use crate::stage::{ExecutionStage, ExecutionStageProto};
use arrow_flight::{FlightClient, Ticket};
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::Ticket;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{internal_datafusion_err, plan_err};
use datafusion::error::Result;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -15,11 +20,8 @@ use prost::Message;
use std::any::Any;
use std::fmt::Formatter;
use std::sync::Arc;
use tonic::transport::Channel;
use url::Url;

use super::combined::CombinedRecordBatchStream;

#[derive(Debug, Clone)]
pub struct ArrowFlightReadExec {
/// the number of the stage we are reading from
@@ -125,8 +127,6 @@ impl ExecutionPlan for ArrowFlightReadExec {
let schema = child_stage.plan.schema();

let stream = async move {
// concurrently build streams for each stage
// TODO: tokio spawn instead here?
Comment on lines -128 to -129

Collaborator (author): I think this is fine. As this is pure IO work, I think it's fair to not spawn it and just join it in a single thread.

Collaborator: That was my thought too, but I left the comment so that we'd get an additional perspective. Ty 👍

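To make the trade-off concrete, here is a minimal sketch (with a hypothetical `fetch_stream` helper, not code from this PR) of joining IO-bound futures on the current task versus spawning them:

use futures::future::try_join_all;

// Stand-in for the per-task do_get call: it awaits the network and does
// essentially no CPU work.
async fn fetch_stream(url: &str) -> Result<Vec<u8>, std::io::Error> {
    Ok(url.as_bytes().to_vec())
}

// What the PR settles on: poll every future concurrently from the current
// task. For pure IO this is sufficient, and nothing needs to be 'static.
async fn joined(urls: &[String]) -> Result<Vec<Vec<u8>>, std::io::Error> {
    try_join_all(urls.iter().map(|u| fetch_stream(u))).await
}

// The alternative raised by the removed TODO: one tokio task per future.
// Spawning adds worker-thread parallelism, which only pays off for
// CPU-bound work, and it forces owned 'static inputs.
async fn spawned(urls: Vec<String>) -> Result<Vec<Vec<u8>>, std::io::Error> {
    let handles: Vec<_> = urls
        .into_iter()
        .map(|u| tokio::spawn(async move { fetch_stream(&u).await }))
        .collect();
    let mut out = Vec::with_capacity(handles.len());
    for handle in handles {
        out.push(handle.await.expect("stream task panicked")?);
    }
    Ok(out)
}
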
let futs = child_stage_tasks.iter().map(|task| async {
let url = task.url()?.ok_or(internal_datafusion_err!(
"ArrowFlightReadExec: task is unassigned, cannot proceed"
@@ -153,31 +153,29 @@ async fn stream_from_stage_task(
ticket: Ticket,
url: &Url,
schema: SchemaRef,
_channel_manager: &ChannelManager,
) -> Result<SendableRecordBatchStream> {
// FIXME: I cannot figure out how to use the arrow_flight::client::FlightClient (a mid-level
// client) with the ChannelManager, so we will create a new Channel directly for now
channel_manager: &ChannelManager,
) -> Result<SendableRecordBatchStream, DataFusionError> {
let channel = channel_manager.get_channel_for_url(&url).await?;

//let channel = channel_manager.get_channel_for_url(&url).await?;

let channel = Channel::from_shared(url.to_string())
.map_err(|e| internal_datafusion_err!("Failed to create channel from URL: {e:#?}"))?
.connect()
.await
.map_err(|e| internal_datafusion_err!("Failed to connect to channel: {e:#?}"))?;

let mut client = FlightClient::new(channel);

let flight_stream = client
let mut client = FlightServiceClient::new(channel);
let stream = client
.do_get(ticket)
.await
.map_err(|e| internal_datafusion_err!("Failed to execute do_get for ticket: {e:#?}"))?;

let record_batch_stream = RecordBatchStreamAdapter::new(
.map_err(|err| {
tonic_status_to_datafusion_error(&err)
.unwrap_or_else(|| DataFusionError::External(Box::new(err)))
})?
.into_inner()
.map_err(|err| FlightError::Tonic(Box::new(err)));

let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err {
FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status)
.unwrap_or_else(|| DataFusionError::External(Box::new(status))),
err => DataFusionError::External(Box::new(err)),
});

Ok(Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
flight_stream
.map_err(|e| internal_datafusion_err!("Failed to decode flight stream: {e:#?}")),
);

Ok(Box::pin(record_batch_stream) as SendableRecordBatchStream)
stream,
)))
}
2 changes: 1 addition & 1 deletion src/stage/display.rs
@@ -14,7 +14,7 @@ use std::fmt::Write;

use datafusion::{
error::Result,
physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan},
physical_plan::{DisplayAs, DisplayFormatType},
};

use crate::{
18 changes: 5 additions & 13 deletions src/stage/proto.rs
@@ -10,7 +10,6 @@ use datafusion_proto::{
physical_plan::{AsExecutionPlan, PhysicalExtensionCodec},
protobuf::PhysicalPlanNode,
};
use prost::Message;

use crate::{plan::DistributedCodec, task::ExecutionTask};

@@ -100,9 +99,7 @@ pub fn stage_from_proto(
}

// add tests for round trip to and from a proto message for ExecutionStage
/* TODO: broken for now
#[cfg(test)]

mod tests {
use std::sync::Arc;

@@ -111,19 +108,13 @@ mod tests {
array::{RecordBatch, StringArray, UInt8Array},
datatypes::{DataType, Field, Schema},
},
catalog::memory::DataSourceExec,
common::{internal_datafusion_err, internal_err},
common::internal_datafusion_err,
datasource::MemTable,
error::{DataFusionError, Result},
error::Result,
execution::context::SessionContext,
prelude::SessionConfig,
};
use datafusion_proto::{
physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec},
protobuf::PhysicalPlanNode,
};
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use prost::Message;
use uuid::Uuid;

use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto};

@@ -148,6 +139,7 @@ mod tests {
}

#[tokio::test]
#[ignore]
async fn test_execution_stage_proto_round_trip() -> Result<()> {
let ctx = SessionContext::new();
let mem_table = create_mem_table();
@@ -196,4 +188,4 @@
assert_eq!(stage.name, round_trip_stage.name);
Ok(())
}
}*/
}
3 changes: 1 addition & 2 deletions src/task.rs
@@ -3,9 +3,8 @@ use std::fmt::Display;
use std::fmt::Formatter;

use datafusion::common::internal_datafusion_err;
use prost::Message;

use datafusion::error::Result;

use url::Url;

#[derive(Clone, PartialEq, ::prost::Message)]
2 changes: 2 additions & 0 deletions tests/common/mod.rs
@@ -1,3 +1,5 @@
pub mod insta;
pub mod localhost;
pub mod parquet;
pub mod plan;
pub mod tpch;
68 changes: 68 additions & 0 deletions tests/common/plan.rs
@@ -0,0 +1,68 @@
use datafusion::common::plan_err;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::error::DataFusionError;
use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_distributed::ArrowFlightReadExec;
use std::sync::Arc;

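/// Test helper: splits a partial/final aggregate pair across network
/// boundaries. The partial aggregate reads its input through an
/// ArrowFlightReadExec hash-partitioned on the group keys (stage 0), and the
/// final aggregate reads the partial results through a round-robin
/// ArrowFlightReadExec (stage 1). Errors if the pair is malformed.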
pub fn distribute_aggregate(
plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let mut aggregate_partial_found = false;
Ok(plan
.transform_up(|node| {
let Some(agg) = node.as_any().downcast_ref::<AggregateExec>() else {
return Ok(Transformed::no(node));
};

match agg.mode() {
AggregateMode::Partial => {
if aggregate_partial_found {
return plan_err!("Two consecutive partial aggregations found");
}
aggregate_partial_found = true;
let expr = agg
.group_expr()
.expr()
.iter()
.map(|(v, _)| Arc::clone(v))
.collect::<Vec<_>>();

if node.children().len() != 1 {
return plan_err!("Aggregate must have exactly one child");
}
let child = node.children()[0].clone();

let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
Partitioning::Hash(expr, 1),
child.schema(),
0,
))])?;
Ok(Transformed::yes(node))
}
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single
| AggregateMode::SinglePartitioned => {
if !aggregate_partial_found {
return plan_err!("No partial aggregate found before the final one");
}

if node.children().len() != 1 {
return plan_err!("Aggregate must have exactly one child");
}
let child = node.children()[0].clone();

let node = node.with_new_children(vec![Arc::new(ArrowFlightReadExec::new(
Partitioning::RoundRobinBatch(8),
child.schema(),
1,
))])?;
Ok(Transformed::yes(node))
}
}
})?
.data)
}
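For reference, a sketch of how a test might drive this helper; the context setup and SQL are illustrative assumptions, not code from this PR:

use datafusion::error::DataFusionError;
use datafusion::prelude::SessionContext;

async fn distributed_count(ctx: &SessionContext) -> Result<(), DataFusionError> {
    // Assumes a table `t` was registered on the context beforehand.
    let df = ctx.sql("SELECT id, count(*) FROM t GROUP BY id").await?;
    let plan = df.create_physical_plan().await?;
    // The partial aggregate now reads through a hash-partitioned
    // ArrowFlightReadExec (stage 0) and the final aggregate through a
    // round-robin one (stage 1).
    let _distributed = distribute_aggregate(plan)?;
    Ok(())
}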
File renamed without changes.
14 changes: 8 additions & 6 deletions tests/custom_extension_codec.rs
@@ -1,6 +1,6 @@
#[allow(dead_code)]
mod common;
/*

#[cfg(test)]
mod tests {
use crate::assert_snapshot;
@@ -27,7 +27,7 @@ mod tests {
use datafusion::physical_plan::{
displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use datafusion_distributed::{assign_stages, ArrowFlightReadExec, SessionBuilder};
use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_proto::protobuf::proto_error;
use futures::{stream, TryStreamExt};
@@ -37,6 +37,7 @@ mod tests {
use std::sync::Arc;

#[tokio::test]
#[ignore]
async fn custom_extension_codec() -> Result<(), Box<dyn std::error::Error>> {
#[derive(Clone)]
struct CustomSessionBuilder;
@@ -66,7 +67,6 @@ mod tests {
");

let distributed_plan = build_plan(true)?;
let distributed_plan = assign_stages(distributed_plan, &ctx)?;

assert_snapshot!(displayable(distributed_plan.as_ref()).indent(true).to_string(), @r"
SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
@@ -124,8 +124,9 @@ mod tests {

if distributed {
plan = Arc::new(ArrowFlightReadExec::new(
plan.clone(),
Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1),
plan.clone().schema(),
0, // TODO: stage num should be assigned by someone else
));
}

@@ -139,8 +140,9 @@ mod tests {

if distributed {
plan = Arc::new(ArrowFlightReadExec::new(
plan.clone(),
Partitioning::RoundRobinBatch(10),
plan.clone().schema(),
1, // TODO: stage num should be assigned by someone else
));

plan = Arc::new(RepartitionExec::try_new(
@@ -266,4 +268,4 @@ mod tests {
.map_err(|err| proto_error(format!("{err}")))
}
}
}*/
}