diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index b0c37d1..ae15fc8 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,10 +1,15 @@ +use super::combined::CombinedRecordBatchStream; use crate::channel_manager::ChannelManager; +use crate::errors::tonic_status_to_datafusion_error; use crate::flight_service::DoGet; use crate::stage::{ExecutionStage, ExecutionStageProto}; -use arrow_flight::{FlightClient, Ticket}; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::error::FlightError; +use arrow_flight::flight_service_client::FlightServiceClient; +use arrow_flight::Ticket; use datafusion::arrow::datatypes::SchemaRef; use datafusion::common::{internal_datafusion_err, plan_err}; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -15,11 +20,8 @@ use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; -use tonic::transport::Channel; use url::Url; -use super::combined::CombinedRecordBatchStream; - #[derive(Debug, Clone)] pub struct ArrowFlightReadExec { /// the number of the stage we are reading from @@ -125,8 +127,6 @@ impl ExecutionPlan for ArrowFlightReadExec { let schema = child_stage.plan.schema(); let stream = async move { - // concurrenly build streams for each stage - // TODO: tokio spawn instead here? let futs = child_stage_tasks.iter().map(|task| async { let url = task.url()?.ok_or(internal_datafusion_err!( "ArrowFlightReadExec: task is unassigned, cannot proceed" @@ -153,31 +153,29 @@ async fn stream_from_stage_task( ticket: Ticket, url: &Url, schema: SchemaRef, - _channel_manager: &ChannelManager, -) -> Result { - // FIXME: I cannot figure how how to use the arrow_flight::client::FlightClient (a mid level - // client) with the ChannelManager, so we willc create a new Channel directly for now + channel_manager: &ChannelManager, +) -> Result { + let channel = channel_manager.get_channel_for_url(&url).await?; - //let channel = channel_manager.get_channel_for_url(&url).await?; - - let channel = Channel::from_shared(url.to_string()) - .map_err(|e| internal_datafusion_err!("Failed to create channel from URL: {e:#?}"))? - .connect() - .await - .map_err(|e| internal_datafusion_err!("Failed to connect to channel: {e:#?}"))?; - - let mut client = FlightClient::new(channel); - - let flight_stream = client + let mut client = FlightServiceClient::new(channel); + let stream = client .do_get(ticket) .await - .map_err(|e| internal_datafusion_err!("Failed to execute do_get for ticket: {e:#?}"))?; - - let record_batch_stream = RecordBatchStreamAdapter::new( + .map_err(|err| { + tonic_status_to_datafusion_error(&err) + .unwrap_or_else(|| DataFusionError::External(Box::new(err))) + })? + .into_inner() + .map_err(|err| FlightError::Tonic(Box::new(err))); + + let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err { + FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status) + .unwrap_or_else(|| DataFusionError::External(Box::new(status))), + err => DataFusionError::External(Box::new(err)), + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( schema.clone(), - flight_stream - .map_err(|e| internal_datafusion_err!("Failed to decode flight stream: {e:#?}")), - ); - - Ok(Box::pin(record_batch_stream) as SendableRecordBatchStream) + stream, + ))) }