Skip to content

Commit 49af3ae

Browse files
authored
Fix ArrowFlightReadExec result streaming (#77)
* Uncomment tests in favor of just #[ignore]-ing them * Remove unused import statements * Move common tpch module to common * Unignore one more test * Fix ArrowFlightReadExec
1 parent c32a86f commit 49af3ae

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

src/plan/arrow_flight_read.rs

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
use super::combined::CombinedRecordBatchStream;
12
use crate::channel_manager::ChannelManager;
3+
use crate::errors::tonic_status_to_datafusion_error;
24
use crate::flight_service::DoGet;
35
use crate::stage::{ExecutionStage, ExecutionStageProto};
4-
use arrow_flight::{FlightClient, Ticket};
6+
use arrow_flight::decode::FlightRecordBatchStream;
7+
use arrow_flight::error::FlightError;
8+
use arrow_flight::flight_service_client::FlightServiceClient;
9+
use arrow_flight::Ticket;
510
use datafusion::arrow::datatypes::SchemaRef;
611
use datafusion::common::{internal_datafusion_err, plan_err};
7-
use datafusion::error::Result;
12+
use datafusion::error::{DataFusionError, Result};
813
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
914
use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
1015
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
@@ -15,11 +20,8 @@ use prost::Message;
1520
use std::any::Any;
1621
use std::fmt::Formatter;
1722
use std::sync::Arc;
18-
use tonic::transport::Channel;
1923
use url::Url;
2024

21-
use super::combined::CombinedRecordBatchStream;
22-
2325
#[derive(Debug, Clone)]
2426
pub struct ArrowFlightReadExec {
2527
/// the number of the stage we are reading from
@@ -125,8 +127,6 @@ impl ExecutionPlan for ArrowFlightReadExec {
125127
let schema = child_stage.plan.schema();
126128

127129
let stream = async move {
128-
// concurrenly build streams for each stage
129-
// TODO: tokio spawn instead here?
130130
let futs = child_stage_tasks.iter().map(|task| async {
131131
let url = task.url()?.ok_or(internal_datafusion_err!(
132132
"ArrowFlightReadExec: task is unassigned, cannot proceed"
@@ -153,31 +153,29 @@ async fn stream_from_stage_task(
153153
ticket: Ticket,
154154
url: &Url,
155155
schema: SchemaRef,
156-
_channel_manager: &ChannelManager,
157-
) -> Result<SendableRecordBatchStream> {
158-
// FIXME: I cannot figure how how to use the arrow_flight::client::FlightClient (a mid level
159-
// client) with the ChannelManager, so we willc create a new Channel directly for now
156+
channel_manager: &ChannelManager,
157+
) -> Result<SendableRecordBatchStream, DataFusionError> {
158+
let channel = channel_manager.get_channel_for_url(&url).await?;
160159

161-
//let channel = channel_manager.get_channel_for_url(&url).await?;
162-
163-
let channel = Channel::from_shared(url.to_string())
164-
.map_err(|e| internal_datafusion_err!("Failed to create channel from URL: {e:#?}"))?
165-
.connect()
166-
.await
167-
.map_err(|e| internal_datafusion_err!("Failed to connect to channel: {e:#?}"))?;
168-
169-
let mut client = FlightClient::new(channel);
170-
171-
let flight_stream = client
160+
let mut client = FlightServiceClient::new(channel);
161+
let stream = client
172162
.do_get(ticket)
173163
.await
174-
.map_err(|e| internal_datafusion_err!("Failed to execute do_get for ticket: {e:#?}"))?;
175-
176-
let record_batch_stream = RecordBatchStreamAdapter::new(
164+
.map_err(|err| {
165+
tonic_status_to_datafusion_error(&err)
166+
.unwrap_or_else(|| DataFusionError::External(Box::new(err)))
167+
})?
168+
.into_inner()
169+
.map_err(|err| FlightError::Tonic(Box::new(err)));
170+
171+
let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err {
172+
FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status)
173+
.unwrap_or_else(|| DataFusionError::External(Box::new(status))),
174+
err => DataFusionError::External(Box::new(err)),
175+
});
176+
177+
Ok(Box::pin(RecordBatchStreamAdapter::new(
177178
schema.clone(),
178-
flight_stream
179-
.map_err(|e| internal_datafusion_err!("Failed to decode flight stream: {e:#?}")),
180-
);
181-
182-
Ok(Box::pin(record_batch_stream) as SendableRecordBatchStream)
179+
stream,
180+
)))
183181
}

0 commit comments

Comments
 (0)