diff --git a/src/errors/mod.rs b/src/errors/mod.rs index e310756..ea850df 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -1,6 +1,7 @@ #![allow(clippy::upper_case_acronyms, clippy::vec_box)] use crate::errors::datafusion_error::DataFusionErrorProto; +use arrow_flight::error::FlightError; use datafusion::common::internal_datafusion_err; use datafusion::error::DataFusionError; use prost::Message; @@ -48,3 +49,19 @@ pub fn tonic_status_to_datafusion_error(status: &tonic::Status) -> Option DataFusionError { + tonic_status_to_datafusion_error(&err) + .unwrap_or_else(|| DataFusionError::External(Box::new(err))) +} + +/// Same as [tonic_status_to_datafusion_error] but suitable to be used in `.map_err` calls that +/// accept a [FlightError] error. +pub fn map_flight_to_datafusion_error(err: FlightError) -> DataFusionError { + match err { + FlightError::Tonic(status) => map_status_to_datafusion_error(*status), + err => DataFusionError::External(Box::new(err)), + } +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index adc7d74..8ae32e0 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,12 +1,10 @@ use super::service::StageKey; -use crate::common::ComposedPhysicalExtensionCodec; use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; use crate::plan::{DistributedCodec, PartitionGroup}; use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto}; -use crate::user_codec_ext::get_distributed_user_codec; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; @@ -115,18 +113,13 @@ impl ArrowFlightEndpoint { .await .map_err(|err| datafusion_error_to_tonic_status(&err))?; - let mut combined_codec = ComposedPhysicalExtensionCodec::default(); - combined_codec.push(DistributedCodec); - if let Some(ref user_codec) = get_distributed_user_codec(state.config()) { - combined_codec.push_arc(Arc::clone(user_codec)); - } - - let stage = - stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &combined_codec) - .map(Arc::new) - .map_err(|err| { - Status::invalid_argument(format!("Cannot decode stage proto: {err}")) - })?; + let codec = DistributedCodec::new_combined_with_user(state.config()); + + let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec) + .map(Arc::new) + .map_err(|err| { + Status::invalid_argument(format!("Cannot decode stage proto: {err}")) + })?; // Add the extensions that might be required for ExecutionPlan nodes in the plan let config = state.config_mut(); diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index 04cf6d9..44d5f9b 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,12 +1,9 @@ -use super::combined::CombinedRecordBatchStream; use crate::channel_manager_ext::get_distributed_channel_resolver; -use crate::common::ComposedPhysicalExtensionCodec; use crate::config_extension_ext::ContextGrpcMetadata; -use crate::errors::tonic_status_to_datafusion_error; +use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; use crate::flight_service::{DoGet, StageKey}; use crate::plan::DistributedCodec; use crate::stage::{proto_from_stage, ExecutionStage}; -use crate::user_codec_ext::get_distributed_user_codec; use crate::ChannelResolver; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; @@ -20,7 +17,7 @@ use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; -use futures::{future, TryFutureExt, TryStreamExt}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; use http::Extensions; use prost::Message; use std::any::Any; @@ -28,7 +25,6 @@ use std::fmt::Formatter; use std::sync::Arc; use tonic::metadata::MetadataMap; use tonic::Request; -use url::Url; /// This node has two variants. /// 1. Pending: it acts as a placeholder for the distributed optimization step to mark it as ready. @@ -187,115 +183,66 @@ impl ExecutionPlan for ArrowFlightReadExec { .session_config() .get_extension::(); - let mut combined_codec = ComposedPhysicalExtensionCodec::default(); - combined_codec.push(DistributedCodec {}); - if let Some(ref user_codec) = get_distributed_user_codec(context.session_config()) { - combined_codec.push_arc(Arc::clone(user_codec)); - } + let codec = DistributedCodec::new_combined_with_user(context.session_config()); - let child_stage_proto = proto_from_stage(child_stage, &combined_codec).map_err(|e| { + let child_stage_proto = proto_from_stage(child_stage, &codec).map_err(|e| { internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}") })?; - let schema = child_stage.plan.schema(); - let child_stage_tasks = child_stage.tasks.clone(); let child_stage_num = child_stage.num as u64; let query_id = stage.query_id.to_string(); - let stream = async move { - let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| { - let child_stage_proto = child_stage_proto.clone(); - let channel_resolver = channel_resolver.clone(); - let schema = schema.clone(); - let query_id = query_id.clone(); - let flight_metadata = flight_metadata - .as_ref() - .map(|v| v.as_ref().clone()) - .unwrap_or_default(); - let key = StageKey { - query_id, - stage_id: child_stage_num, - task_number: i as u64, - }; - async move { - let url = task.url()?.ok_or(internal_datafusion_err!( - "ArrowFlightReadExec: task is unassigned, cannot proceed" - ))?; + let context_headers = flight_metadata + .as_ref() + .map(|v| v.as_ref().clone()) + .unwrap_or_default(); - let ticket_bytes = DoGet { - stage_proto: Some(child_stage_proto), + let stream = child_stage_tasks.into_iter().enumerate().map(|(i, task)| { + let channel_resolver = Arc::clone(&channel_resolver); + + let ticket = Request::from_parts( + MetadataMap::from_headers(context_headers.0.clone()), + Extensions::default(), + Ticket { + ticket: DoGet { + stage_proto: Some(child_stage_proto.clone()), partition: partition as u64, - stage_key: Some(key), + stage_key: Some(StageKey { + query_id: query_id.clone(), + stage_id: child_stage_num, + task_number: i as u64, + }), task_number: i as u64, } .encode_to_vec() - .into(); + .into(), + }, + ); - let ticket = Ticket { - ticket: ticket_bytes, - }; + async move { + let url = task.url()?.ok_or(internal_datafusion_err!( + "ArrowFlightReadExec: task is unassigned, cannot proceed" + ))?; - stream_from_stage_task( - ticket, - flight_metadata, - &url, - schema.clone(), - &channel_resolver, - ) + let channel = channel_resolver.get_channel_for_url(&url).await?; + let stream = FlightServiceClient::new(channel) + .do_get(ticket) .await - } - }); + .map_err(map_status_to_datafusion_error)? + .into_inner() + .map_err(|err| FlightError::Tonic(Box::new(err))); - let streams = future::try_join_all(futs).await?; - - let combined_stream = CombinedRecordBatchStream::try_new(schema, streams)?; - - Ok(combined_stream) - } - .try_flatten_stream(); + Ok(FlightRecordBatchStream::new_from_flight_data(stream) + .map_err(map_flight_to_datafusion_error)) + } + .try_flatten_stream() + .boxed() + }); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - stream, + futures::stream::select_all(stream), ))) } } - -async fn stream_from_stage_task( - ticket: Ticket, - metadata: ContextGrpcMetadata, - url: &Url, - schema: SchemaRef, - channel_manager: &impl ChannelResolver, -) -> Result { - let channel = channel_manager.get_channel_for_url(url).await?; - - let ticket = Request::from_parts( - MetadataMap::from_headers(metadata.0), - Extensions::default(), - ticket, - ); - - let mut client = FlightServiceClient::new(channel); - let stream = client - .do_get(ticket) - .await - .map_err(|err| { - tonic_status_to_datafusion_error(&err) - .unwrap_or_else(|| DataFusionError::External(Box::new(err))) - })? - .into_inner() - .map_err(|err| FlightError::Tonic(Box::new(err))); - - let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err { - FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status) - .unwrap_or_else(|| DataFusionError::External(Box::new(status))), - err => DataFusionError::External(Box::new(err)), - }); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), - stream, - ))) -} diff --git a/src/plan/codec.rs b/src/plan/codec.rs index 793de1f..48abe72 100644 --- a/src/plan/codec.rs +++ b/src/plan/codec.rs @@ -1,7 +1,11 @@ +use super::PartitionIsolatorExec; +use crate::common::ComposedPhysicalExtensionCodec; use crate::plan::arrow_flight_read::ArrowFlightReadExec; +use crate::user_codec_ext::get_distributed_user_codec; use datafusion::arrow::datatypes::Schema; use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning; use datafusion_proto::physical_plan::to_proto::serialize_partitioning; use datafusion_proto::physical_plan::PhysicalExtensionCodec; @@ -10,13 +14,22 @@ use datafusion_proto::protobuf::proto_error; use prost::Message; use std::sync::Arc; -use super::PartitionIsolatorExec; - /// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and /// deserializing the custom ExecutionPlans in this project #[derive(Debug)] pub struct DistributedCodec; +impl DistributedCodec { + pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec { + let mut combined_codec = ComposedPhysicalExtensionCodec::default(); + combined_codec.push(DistributedCodec {}); + if let Some(ref user_codec) = get_distributed_user_codec(cfg) { + combined_codec.push_arc(Arc::clone(user_codec)); + } + combined_codec + } +} + impl PhysicalExtensionCodec for DistributedCodec { fn try_decode( &self, diff --git a/src/plan/combined.rs b/src/plan/combined.rs deleted file mode 100644 index ffa3f15..0000000 --- a/src/plan/combined.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::{ - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use datafusion::error::Result; -use datafusion::{ - arrow::{array::RecordBatch, datatypes::SchemaRef}, - common::internal_err, - error::DataFusionError, - execution::{RecordBatchStream, SendableRecordBatchStream}, -}; -use futures::Stream; - -pub(crate) struct CombinedRecordBatchStream { - /// Schema wrapped by Arc - schema: SchemaRef, - /// Stream entries - entries: Vec, -} - -impl CombinedRecordBatchStream { - /// Create an CombinedRecordBatchStream - pub fn try_new(schema: SchemaRef, entries: Vec) -> Result { - if entries.is_empty() { - return internal_err!("Cannot create CombinedRecordBatchStream with no entries"); - } - Ok(Self { schema, entries }) - } -} - -impl RecordBatchStream for CombinedRecordBatchStream { - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } -} - -impl Stream for CombinedRecordBatchStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - use Poll::*; - - let start = 0; - let mut idx = start; - - for _ in 0..self.entries.len() { - let stream = self.entries.get_mut(idx).unwrap(); - - match Pin::new(stream).poll_next(cx) { - Ready(Some(val)) => return Ready(Some(val)), - Ready(None) => { - // Remove the entry - self.entries.swap_remove(idx); - - // Check if this was the last entry, if so the cursor needs - // to wrap - if idx == self.entries.len() { - idx = 0; - } else if idx < start && start <= self.entries.len() { - // The stream being swapped into the current index has - // already been polled, so skip it. - idx = idx.wrapping_add(1) % self.entries.len(); - } - } - Pending => { - idx = idx.wrapping_add(1) % self.entries.len(); - } - } - } - - // If the map is empty, then the stream is complete. - if self.entries.is_empty() { - Ready(None) - } else { - Pending - } - } -} diff --git a/src/plan/mod.rs b/src/plan/mod.rs index 21b6ce5..7d2a98e 100644 --- a/src/plan/mod.rs +++ b/src/plan/mod.rs @@ -1,6 +1,5 @@ mod arrow_flight_read; mod codec; -mod combined; mod isolator; pub use arrow_flight_read::ArrowFlightReadExec;