
Commit d0fad3a

Refactor arrow_flight_read.rs and friends
1 parent: bd512cb

6 files changed: +81 additions, -192 deletions

src/errors/mod.rs

Lines changed: 17 additions & 0 deletions
@@ -1,6 +1,7 @@
 #![allow(clippy::upper_case_acronyms, clippy::vec_box)]
 
 use crate::errors::datafusion_error::DataFusionErrorProto;
+use arrow_flight::error::FlightError;
 use datafusion::common::internal_datafusion_err;
 use datafusion::error::DataFusionError;
 use prost::Message;
@@ -48,3 +49,19 @@ pub fn tonic_status_to_datafusion_error(status: &tonic::Status) -> Option<DataFu
         )),
     }
 }
+
+/// Same as [tonic_status_to_datafusion_error] but suitable to be used in `.map_err` calls that
+/// accept a [tonic::Status] error.
+pub fn map_status_to_datafusion_error(err: tonic::Status) -> DataFusionError {
+    tonic_status_to_datafusion_error(&err)
+        .unwrap_or_else(|| DataFusionError::External(Box::new(err)))
+}
+
+/// Same as [tonic_status_to_datafusion_error] but suitable to be used in `.map_err` calls that
+/// accept a [FlightError] error.
+pub fn map_flight_to_datafusion_error(err: FlightError) -> DataFusionError {
+    match err {
+        FlightError::Tonic(status) => map_status_to_datafusion_error(*status),
+        err => DataFusionError::External(Box::new(err)),
+    }
+}
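Note: the two new helpers exist so call sites can pass a named function straight to `.map_err` instead of repeating the `unwrap_or_else(|| DataFusionError::External(...))` closure inline. A minimal sketch of the intended call-site shape (the `demo` function and its `Err` values are hypothetical; the helpers and error types are the ones added above):

use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error};
use arrow_flight::error::FlightError;
use datafusion::error::DataFusionError;

// Hypothetical call sites: both helpers slot directly into `.map_err`.
fn demo(status: tonic::Status, flight_err: FlightError) -> (DataFusionError, DataFusionError) {
    let from_status = Err::<(), _>(status)
        .map_err(map_status_to_datafusion_error) // tonic::Status -> DataFusionError
        .unwrap_err();
    let from_flight = Err::<(), _>(flight_err)
        .map_err(map_flight_to_datafusion_error) // FlightError -> DataFusionError
        .unwrap_err();
    (from_status, from_flight)
}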

src/flight_service/do_get.rs

Lines changed: 7 additions & 14 deletions
@@ -1,12 +1,10 @@
 use super::service::StageKey;
-use crate::common::ComposedPhysicalExtensionCodec;
 use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::errors::datafusion_error_to_tonic_status;
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::flight_service::session_builder::DistributedSessionBuilderContext;
 use crate::plan::{DistributedCodec, PartitionGroup};
 use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto};
-use crate::user_codec_ext::get_distributed_user_codec;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
@@ -115,18 +113,13 @@ impl ArrowFlightEndpoint {
             .await
             .map_err(|err| datafusion_error_to_tonic_status(&err))?;
 
-        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
-        combined_codec.push(DistributedCodec);
-        if let Some(ref user_codec) = get_distributed_user_codec(state.config()) {
-            combined_codec.push_arc(Arc::clone(user_codec));
-        }
-
-        let stage =
-            stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &combined_codec)
-                .map(Arc::new)
-                .map_err(|err| {
-                    Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
-                })?;
+        let codec = DistributedCodec::new_combined_with_user(state.config());
+
+        let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec)
+            .map(Arc::new)
+            .map_err(|err| {
+                Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
+            })?;
 
         // Add the extensions that might be required for ExecutionPlan nodes in the plan
         let config = state.config_mut();

src/plan/arrow_flight_read.rs

Lines changed: 42 additions & 95 deletions
@@ -1,12 +1,9 @@
-use super::combined::CombinedRecordBatchStream;
 use crate::channel_manager_ext::get_distributed_channel_resolver;
-use crate::common::ComposedPhysicalExtensionCodec;
 use crate::config_extension_ext::ContextGrpcMetadata;
-use crate::errors::tonic_status_to_datafusion_error;
+use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error};
 use crate::flight_service::{DoGet, StageKey};
 use crate::plan::DistributedCodec;
 use crate::stage::{proto_from_stage, ExecutionStage};
-use crate::user_codec_ext::get_distributed_user_codec;
 use crate::ChannelResolver;
 use arrow_flight::decode::FlightRecordBatchStream;
 use arrow_flight::error::FlightError;
@@ -20,15 +17,14 @@ use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
-use futures::{future, TryFutureExt, TryStreamExt};
+use futures::{StreamExt, TryFutureExt, TryStreamExt};
 use http::Extensions;
 use prost::Message;
 use std::any::Any;
 use std::fmt::Formatter;
 use std::sync::Arc;
 use tonic::metadata::MetadataMap;
 use tonic::Request;
-use url::Url;
 
 /// This node has two variants.
 /// 1. Pending: it acts as a placeholder for the distributed optimization step to mark it as ready.
@@ -187,115 +183,66 @@ impl ExecutionPlan for ArrowFlightReadExec {
             .session_config()
             .get_extension::<ContextGrpcMetadata>();
 
-        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
-        combined_codec.push(DistributedCodec {});
-        if let Some(ref user_codec) = get_distributed_user_codec(context.session_config()) {
-            combined_codec.push_arc(Arc::clone(user_codec));
-        }
+        let codec = DistributedCodec::new_combined_with_user(context.session_config());
 
-        let child_stage_proto = proto_from_stage(child_stage, &combined_codec).map_err(|e| {
+        let child_stage_proto = proto_from_stage(child_stage, &codec).map_err(|e| {
             internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}")
         })?;
 
-        let schema = child_stage.plan.schema();
-
         let child_stage_tasks = child_stage.tasks.clone();
         let child_stage_num = child_stage.num as u64;
         let query_id = stage.query_id.to_string();
 
-        let stream = async move {
-            let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| {
-                let child_stage_proto = child_stage_proto.clone();
-                let channel_resolver = channel_resolver.clone();
-                let schema = schema.clone();
-                let query_id = query_id.clone();
-                let flight_metadata = flight_metadata
-                    .as_ref()
-                    .map(|v| v.as_ref().clone())
-                    .unwrap_or_default();
-                let key = StageKey {
-                    query_id,
-                    stage_id: child_stage_num,
-                    task_number: i as u64,
-                };
-                async move {
-                    let url = task.url()?.ok_or(internal_datafusion_err!(
-                        "ArrowFlightReadExec: task is unassigned, cannot proceed"
-                    ))?;
+        let context_headers = flight_metadata
+            .as_ref()
+            .map(|v| v.as_ref().clone())
+            .unwrap_or_default();
 
-                    let ticket_bytes = DoGet {
-                        stage_proto: Some(child_stage_proto),
+        let stream = child_stage_tasks.into_iter().enumerate().map(|(i, task)| {
+            let channel_resolver = Arc::clone(&channel_resolver);
+
+            let ticket = Request::from_parts(
+                MetadataMap::from_headers(context_headers.0.clone()),
+                Extensions::default(),
+                Ticket {
+                    ticket: DoGet {
+                        stage_proto: Some(child_stage_proto.clone()),
                         partition: partition as u64,
-                        stage_key: Some(key),
+                        stage_key: Some(StageKey {
+                            query_id: query_id.clone(),
+                            stage_id: child_stage_num,
+                            task_number: i as u64,
+                        }),
                         task_number: i as u64,
                     }
                     .encode_to_vec()
-                    .into();
+                    .into(),
+                },
+            );
 
-                    let ticket = Ticket {
-                        ticket: ticket_bytes,
-                    };
+            async move {
+                let url = task.url()?.ok_or(internal_datafusion_err!(
+                    "ArrowFlightReadExec: task is unassigned, cannot proceed"
+                ))?;
 
-                    stream_from_stage_task(
-                        ticket,
-                        flight_metadata,
-                        &url,
-                        schema.clone(),
-                        &channel_resolver,
-                    )
+                let channel = channel_resolver.get_channel_for_url(&url).await?;
+                let stream = FlightServiceClient::new(channel)
+                    .do_get(ticket)
                     .await
-                }
-            });
+                    .map_err(map_status_to_datafusion_error)?
+                    .into_inner()
+                    .map_err(|err| FlightError::Tonic(Box::new(err)));
 
-            let streams = future::try_join_all(futs).await?;
-
-            let combined_stream = CombinedRecordBatchStream::try_new(schema, streams)?;
-
-            Ok(combined_stream)
-        }
-        .try_flatten_stream();
+                Ok(FlightRecordBatchStream::new_from_flight_data(stream)
+                    .map_err(map_flight_to_datafusion_error))
+            }
+            .try_flatten_stream()
+            .boxed()
+        });
 
         Ok(Box::pin(RecordBatchStreamAdapter::new(
             self.schema(),
-            stream,
+            futures::stream::select_all(stream),
         )))
     }
 }
-
-async fn stream_from_stage_task(
-    ticket: Ticket,
-    metadata: ContextGrpcMetadata,
-    url: &Url,
-    schema: SchemaRef,
-    channel_manager: &impl ChannelResolver,
-) -> Result<SendableRecordBatchStream, DataFusionError> {
-    let channel = channel_manager.get_channel_for_url(url).await?;
-
-    let ticket = Request::from_parts(
-        MetadataMap::from_headers(metadata.0),
-        Extensions::default(),
-        ticket,
-    );
-
-    let mut client = FlightServiceClient::new(channel);
-    let stream = client
-        .do_get(ticket)
-        .await
-        .map_err(|err| {
-            tonic_status_to_datafusion_error(&err)
-                .unwrap_or_else(|| DataFusionError::External(Box::new(err)))
-        })?
-        .into_inner()
-        .map_err(|err| FlightError::Tonic(Box::new(err)));
-
-    let stream = FlightRecordBatchStream::new_from_flight_data(stream).map_err(|err| match err {
-        FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status)
-            .unwrap_or_else(|| DataFusionError::External(Box::new(status))),
-        err => DataFusionError::External(Box::new(err)),
-    });
-
-    Ok(Box::pin(RecordBatchStreamAdapter::new(
-        schema.clone(),
-        stream,
-    )))
-}
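Note: the per-task streams are no longer awaited with try_join_all and merged by the deleted CombinedRecordBatchStream; they are handed to futures::stream::select_all, which polls every inner stream and yields items as each becomes ready. A standalone sketch of that primitive, with plain integer streams and a tokio runtime standing in for the crate's flight streams (both assumptions, not code from this commit):

use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    // Two boxed streams stand in for the per-task flight streams.
    let a = stream::iter(vec![1, 2]).boxed();
    let b = stream::iter(vec![10, 20]).boxed();

    // select_all merges them, yielding items in readiness order rather
    // than draining one stream before starting the next.
    let merged: Vec<i32> = stream::select_all(vec![a, b]).collect().await;
    assert_eq!(merged.len(), 4);
}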

src/plan/codec.rs

Lines changed: 15 additions & 2 deletions
@@ -1,7 +1,11 @@
+use super::PartitionIsolatorExec;
+use crate::common::ComposedPhysicalExtensionCodec;
 use crate::plan::arrow_flight_read::ArrowFlightReadExec;
+use crate::user_codec_ext::get_distributed_user_codec;
 use datafusion::arrow::datatypes::Schema;
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::SessionConfig;
 use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
 use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
@@ -10,13 +14,22 @@ use datafusion_proto::protobuf::proto_error;
 use prost::Message;
 use std::sync::Arc;
 
-use super::PartitionIsolatorExec;
-
 /// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
 /// deserializing the custom ExecutionPlans in this project
 #[derive(Debug)]
 pub struct DistributedCodec;
 
+impl DistributedCodec {
+    pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec {
+        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
+        combined_codec.push(DistributedCodec {});
+        if let Some(ref user_codec) = get_distributed_user_codec(cfg) {
+            combined_codec.push_arc(Arc::clone(user_codec));
+        }
+        combined_codec
+    }
+}
+
 impl PhysicalExtensionCodec for DistributedCodec {
     fn try_decode(
         &self,
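Note: new_combined_with_user centralizes the codec assembly previously duplicated in do_get.rs and arrow_flight_read.rs. A minimal sketch of the intended usage, assuming a default SessionConfig (in which case no user codec is registered and the composed codec contains only DistributedCodec):

use crate::plan::DistributedCodec;
use datafusion::prelude::SessionConfig;

fn demo() {
    // With a default config, get_distributed_user_codec finds nothing, so
    // the composed codec holds only DistributedCodec; a config carrying a
    // user codec would have it appended after the built-in one.
    let cfg = SessionConfig::new();
    let _codec = DistributedCodec::new_combined_with_user(&cfg);
}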

src/plan/combined.rs

Lines changed: 0 additions & 80 deletions
This file was deleted.

src/plan/mod.rs

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 mod arrow_flight_read;
 mod codec;
-mod combined;
 mod isolator;
 
 pub use arrow_flight_read::ArrowFlightReadExec;
