17 changes: 17 additions & 0 deletions src/errors/mod.rs
@@ -1,6 +1,7 @@
#![allow(clippy::upper_case_acronyms, clippy::vec_box)]

use crate::errors::datafusion_error::DataFusionErrorProto;
use arrow_flight::error::FlightError;
use datafusion::common::internal_datafusion_err;
use datafusion::error::DataFusionError;
use prost::Message;
@@ -48,3 +49,19 @@ pub fn tonic_status_to_datafusion_error(status: &tonic::Status) -> Option<DataFusionError> {
)),
}
}

/// Same as [tonic_status_to_datafusion_error], but suitable for use in `.map_err` calls that
/// take a [tonic::Status] error.
pub fn map_status_to_datafusion_error(err: tonic::Status) -> DataFusionError {
tonic_status_to_datafusion_error(&err)
.unwrap_or_else(|| DataFusionError::External(Box::new(err)))
}

/// Same as [tonic_status_to_datafusion_error], but suitable for use in `.map_err` calls that
/// take a [FlightError] error.
pub fn map_flight_to_datafusion_error(err: FlightError) -> DataFusionError {
match err {
FlightError::Tonic(status) => map_status_to_datafusion_error(*status),
err => DataFusionError::External(Box::new(err)),
}
}
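
For reference, both helpers are written to be passed point-free to `.map_err`, which is how the call sites later in this PR use them. A minimal usage sketch (the `client` and `request` wiring here is illustrative, not part of this diff):

// tonic::Status -> DataFusionError at the gRPC call site.
let response = client
    .do_get(request)
    .await
    .map_err(map_status_to_datafusion_error)?;

// FlightError -> DataFusionError while decoding the stream; a
// FlightError::Tonic is routed back through the tonic::Status path.
let stream = FlightRecordBatchStream::new_from_flight_data(
    response.into_inner().map_err(|e| FlightError::Tonic(Box::new(e))),
)
.map_err(map_flight_to_datafusion_error);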
21 changes: 7 additions & 14 deletions src/flight_service/do_get.rs
@@ -1,12 +1,10 @@
use super::service::StageKey;
use crate::common::ComposedPhysicalExtensionCodec;
use crate::config_extension_ext::ContextGrpcMetadata;
use crate::errors::datafusion_error_to_tonic_status;
use crate::flight_service::service::ArrowFlightEndpoint;
use crate::flight_service::session_builder::DistributedSessionBuilderContext;
use crate::plan::{DistributedCodec, PartitionGroup};
use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto};
use crate::user_codec_ext::get_distributed_user_codec;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
@@ -115,18 +113,13 @@ impl ArrowFlightEndpoint {
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec);
if let Some(ref user_codec) = get_distributed_user_codec(state.config()) {
combined_codec.push_arc(Arc::clone(user_codec));
}

let stage =
stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &combined_codec)
.map(Arc::new)
.map_err(|err| {
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
})?;
let codec = DistributedCodec::new_combined_with_user(state.config());

let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec)
.map(Arc::new)
.map_err(|err| {
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
})?;

// Add the extensions that might be required for ExecutionPlan nodes in the plan
let config = state.config_mut();
137 changes: 42 additions & 95 deletions src/plan/arrow_flight_read.rs
@@ -1,12 +1,9 @@
use super::combined::CombinedRecordBatchStream;
use crate::channel_manager_ext::get_distributed_channel_resolver;
use crate::common::ComposedPhysicalExtensionCodec;
use crate::config_extension_ext::ContextGrpcMetadata;
use crate::errors::tonic_status_to_datafusion_error;
use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error};
use crate::flight_service::{DoGet, StageKey};
use crate::plan::DistributedCodec;
use crate::stage::{proto_from_stage, ExecutionStage};
use crate::user_codec_ext::get_distributed_user_codec;
use crate::ChannelResolver;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
@@ -20,15 +17,14 @@ use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::{future, TryFutureExt, TryStreamExt};
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use http::Extensions;
use prost::Message;
use std::any::Any;
use std::fmt::Formatter;
use std::sync::Arc;
use tonic::metadata::MetadataMap;
use tonic::Request;
use url::Url;

/// This node has two variants.
/// 1. Pending: it acts as a placeholder for the distributed optimization step to mark it as ready.
@@ -187,115 +183,66 @@ impl ExecutionPlan for ArrowFlightReadExec {
.session_config()
.get_extension::<ContextGrpcMetadata>();

let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec {});
if let Some(ref user_codec) = get_distributed_user_codec(context.session_config()) {
combined_codec.push_arc(Arc::clone(user_codec));
}
let codec = DistributedCodec::new_combined_with_user(context.session_config());

let child_stage_proto = proto_from_stage(child_stage, &combined_codec).map_err(|e| {
let child_stage_proto = proto_from_stage(child_stage, &codec).map_err(|e| {
internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}")
})?;

let schema = child_stage.plan.schema();

let child_stage_tasks = child_stage.tasks.clone();
let child_stage_num = child_stage.num as u64;
let query_id = stage.query_id.to_string();

let stream = async move {
let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| {
let child_stage_proto = child_stage_proto.clone();
let channel_resolver = channel_resolver.clone();
let schema = schema.clone();
let query_id = query_id.clone();
let flight_metadata = flight_metadata
.as_ref()
.map(|v| v.as_ref().clone())
.unwrap_or_default();
let key = StageKey {
query_id,
stage_id: child_stage_num,
task_number: i as u64,
};
async move {
let url = task.url()?.ok_or(internal_datafusion_err!(
"ArrowFlightReadExec: task is unassigned, cannot proceed"
))?;
let context_headers = flight_metadata
.as_ref()
.map(|v| v.as_ref().clone())
.unwrap_or_default();

let ticket_bytes = DoGet {
stage_proto: Some(child_stage_proto),
let stream = child_stage_tasks.into_iter().enumerate().map(|(i, task)| {
let channel_resolver = Arc::clone(&channel_resolver);

let ticket = Request::from_parts(
MetadataMap::from_headers(context_headers.0.clone()),
Extensions::default(),
Ticket {
ticket: DoGet {
stage_proto: Some(child_stage_proto.clone()),
partition: partition as u64,
stage_key: Some(key),
stage_key: Some(StageKey {
query_id: query_id.clone(),
stage_id: child_stage_num,
task_number: i as u64,
}),
task_number: i as u64,
}
.encode_to_vec()
.into();
.into(),
},
);

let ticket = Ticket {
ticket: ticket_bytes,
};
async move {
let url = task.url()?.ok_or(internal_datafusion_err!(
"ArrowFlightReadExec: task is unassigned, cannot proceed"
))?;

stream_from_stage_task(
ticket,
flight_metadata,
&url,
schema.clone(),
&channel_resolver,
)
let channel = channel_resolver.get_channel_for_url(&url).await?;
let stream = FlightServiceClient::new(channel)
.do_get(ticket)
.await
}
});
.map_err(map_status_to_datafusion_error)?
.into_inner()
.map_err(|err| FlightError::Tonic(Box::new(err)));

let streams = future::try_join_all(futs).await?;

let combined_stream = CombinedRecordBatchStream::try_new(schema, streams)?;

Ok(combined_stream)
}
.try_flatten_stream();
Ok(FlightRecordBatchStream::new_from_flight_data(stream)
.map_err(map_flight_to_datafusion_error))
}
.try_flatten_stream()
.boxed()
});

Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
stream,
futures::stream::select_all(stream),
)))
}
}

async fn stream_from_stage_task(
ticket: Ticket,
metadata: ContextGrpcMetadata,
url: &Url,
schema: SchemaRef,
channel_manager: &impl ChannelResolver,
) -> Result<SendableRecordBatchStream, DataFusionError> {
let channel = channel_manager.get_channel_for_url(url).await?;

let ticket = Request::from_parts(
MetadataMap::from_headers(metadata.0),
Extensions::default(),
ticket,
);

let mut client = FlightServiceClient::new(channel);
let stream = client
.do_get(ticket)
.await
.map_err(|err| {
tonic_status_to_datafusion_error(&err)
.unwrap_or_else(|| DataFusionError::External(Box::new(err)))
})?
.into_inner()
.map_err(|err| FlightError::Tonic(Box::new(err)));

let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err {
FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status)
.unwrap_or_else(|| DataFusionError::External(Box::new(status))),
err => DataFusionError::External(Box::new(err)),
});

Ok(Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
stream,
)))
}
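
The deleted `CombinedRecordBatchStream` was the merge step behind the old `future::try_join_all` path, which opened every task's Flight stream before yielding anything. The new code above instead builds each per-task stream lazily and merges them with `futures::stream::select_all`, which polls all inner streams and yields batches in arrival order. A toy sketch of that merge behavior, assuming an async context (strings stand in for record batches):

use futures::stream::{self, StreamExt};

// Each boxed stream stands in for one task's record-batch stream.
let per_task = (0..3).map(|task| {
    stream::iter(vec![Ok::<_, String>(format!("batch from task {task}"))]).boxed()
});

// select_all yields items from whichever stream is ready next and
// finishes only once every inner stream is exhausted.
let mut merged = stream::select_all(per_task);
while let Some(batch) = merged.next().await {
    println!("{batch:?}");
}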
17 changes: 15 additions & 2 deletions src/plan/codec.rs
@@ -1,7 +1,11 @@
use super::PartitionIsolatorExec;
use crate::common::ComposedPhysicalExtensionCodec;
use crate::plan::arrow_flight_read::ArrowFlightReadExec;
use crate::user_codec_ext::get_distributed_user_codec;
use datafusion::arrow::datatypes::Schema;
use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
@@ -10,13 +14,22 @@ use datafusion_proto::protobuf::proto_error;
use prost::Message;
use std::sync::Arc;

use super::PartitionIsolatorExec;

/// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
/// deserializing the custom ExecutionPlans in this project
#[derive(Debug)]
pub struct DistributedCodec;

impl DistributedCodec {
pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec {
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec {});
if let Some(ref user_codec) = get_distributed_user_codec(cfg) {
combined_codec.push_arc(Arc::clone(user_codec));
}
combined_codec
}
}

impl PhysicalExtensionCodec for DistributedCodec {
fn try_decode(
&self,
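
The new constructor serves both directions of the codec: `do_get.rs` decodes with it via `stage_from_proto`, and `ArrowFlightReadExec` encodes with it via `proto_from_stage`. Condensed from the two call sites in this PR (error mapping elided):

// Decode side (do_get.rs): stage proto -> ExecutionStage.
let codec = DistributedCodec::new_combined_with_user(state.config());
let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec)?;

// Encode side (arrow_flight_read.rs): ExecutionStage -> stage proto.
let codec = DistributedCodec::new_combined_with_user(context.session_config());
let child_stage_proto = proto_from_stage(child_stage, &codec)?;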
80 changes: 0 additions & 80 deletions src/plan/combined.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/plan/mod.rs
@@ -1,6 +1,5 @@
mod arrow_flight_read;
mod codec;
mod combined;
mod isolator;

pub use arrow_flight_read::ArrowFlightReadExec;