Skip to content

Commit 1cb3e65

Browse files
committed
Refactor ArrowFlightReadExec
1 parent 1e61d51 commit 1cb3e65

File tree

3 files changed

+31
-24
lines changed

3 files changed

+31
-24
lines changed

src/config_extension_ext.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use datafusion::common::{internal_datafusion_err, DataFusionError};
22
use datafusion::config::ConfigExtension;
3+
use datafusion::execution::TaskContext;
34
use datafusion::prelude::SessionConfig;
45
use http::{HeaderMap, HeaderName};
56
use std::error::Error;
@@ -79,6 +80,15 @@ impl ContextGrpcMetadata {
7980
}
8081
self
8182
}
83+
84+
pub fn headers_from_ctx(ctx: &Arc<TaskContext>) -> HeaderMap {
85+
ctx.session_config()
86+
.get_extension::<ContextGrpcMetadata>()
87+
.as_ref()
88+
.map(|v| v.as_ref().clone())
89+
.unwrap_or_default()
90+
.0
91+
}
8292
}
8393

8494
#[cfg(test)]

src/execution_plans/arrow_flight_read.rs

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -155,29 +155,13 @@ impl ExecutionPlan for ArrowFlightReadExec {
155155
let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
156156

157157
// the `ArrowFlightReadExec` node can only be executed in the context of a `StageExec`
158-
let stage = context
159-
.session_config()
160-
.get_extension::<StageExec>()
161-
.ok_or(internal_datafusion_err!(
162-
"ArrowFlightReadExec requires an ExecutionStage in the session config"
163-
))?;
158+
let stage = StageExec::from_ctx(&context)?;
164159

165160
// of our child stages find the one that matches the one we are supposed to be
166161
// reading from
167-
let child_stage = stage
168-
.child_stages_iter()
169-
.find(|s| s.num == self_ready.stage_num)
170-
.ok_or(internal_datafusion_err!(
171-
"ArrowFlightReadExec: no child stage with num {}",
172-
self_ready.stage_num
173-
))?;
174-
175-
let flight_metadata = context
176-
.session_config()
177-
.get_extension::<ContextGrpcMetadata>();
162+
let child_stage = stage.child_stage(self_ready.stage_num)?;
178163

179164
let codec = DistributedCodec::new_combined_with_user(context.session_config());
180-
181165
let child_stage_proto = proto_from_stage(child_stage, &codec).map_err(|e| {
182166
internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}")
183167
})?;
@@ -186,16 +170,13 @@ impl ExecutionPlan for ArrowFlightReadExec {
186170
let child_stage_num = child_stage.num as u64;
187171
let query_id = stage.query_id.to_string();
188172

189-
let context_headers = flight_metadata
190-
.as_ref()
191-
.map(|v| v.as_ref().clone())
192-
.unwrap_or_default();
173+
let context_headers = ContextGrpcMetadata::headers_from_ctx(&context);
193174

194175
let stream = child_stage_tasks.into_iter().enumerate().map(|(i, task)| {
195176
let channel_resolver = Arc::clone(&channel_resolver);
196177

197178
let ticket = Request::from_parts(
198-
MetadataMap::from_headers(context_headers.0.clone()),
179+
MetadataMap::from_headers(context_headers.clone()),
199180
Extensions::default(),
200181
Ticket {
201182
ticket: DoGet {

src/execution_plans/stage.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::channel_resolver_ext::get_distributed_channel_resolver;
22
use crate::{ArrowFlightReadExec, ChannelResolver, PartitionIsolatorExec};
3-
use datafusion::common::internal_err;
3+
use datafusion::common::{internal_datafusion_err, internal_err};
44
use datafusion::error::{DataFusionError, Result};
55
use datafusion::execution::TaskContext;
66
use datafusion::physical_plan::{
@@ -221,6 +221,22 @@ impl StageExec {
221221

222222
Ok(assigned_stage)
223223
}
224+
225+
pub fn from_ctx(ctx: &Arc<TaskContext>) -> Result<Arc<StageExec>, DataFusionError> {
226+
ctx.session_config()
227+
.get_extension::<StageExec>()
228+
.ok_or(internal_datafusion_err!(
229+
"ArrowFlightReadExec requires an ExecutionStage in the session config"
230+
))
231+
}
232+
233+
pub fn child_stage(&self, i: usize) -> Result<&StageExec, DataFusionError> {
234+
self.child_stages_iter()
235+
.find(|s| s.num == i)
236+
.ok_or(internal_datafusion_err!(
237+
"ArrowFlightReadExec: no child stage with num {i}"
238+
))
239+
}
224240
}
225241

226242
impl ExecutionPlan for StageExec {

0 commit comments

Comments
 (0)