Skip to content

Commit 60e3208

Browse files
committed
Add arrow flight endpoint hooks
1 parent 23d640a commit 60e3208

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/flight_service/do_get.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ impl ArrowFlightEndpoint {
9595
.get_or_try_init(|| async {
9696
let proto_node = PhysicalPlanNode::try_decode(doget.plan_proto.as_ref())?;
9797
let plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?;
98+
let plan = (self.hooks.on_plan)(plan);
9899

99100
// Initialize partition count to the number of partitions in the stage
100101
let total_partitions = plan.properties().partitioning.partition_count();
@@ -226,8 +227,16 @@ mod tests {
226227
#[tokio::test]
227228
async fn test_task_data_partition_counting() {
228229
// Create ArrowFlightEndpoint with DefaultSessionBuilder
229-
let endpoint =
230+
let mut endpoint =
230231
ArrowFlightEndpoint::try_new(DefaultSessionBuilder).expect("Failed to create endpoint");
232+
let plans_received = Arc::new(AtomicUsize::default());
233+
{
234+
let plans_received = Arc::clone(&plans_received);
235+
endpoint.on_plan(move |plan| {
236+
plans_received.fetch_add(1, Ordering::SeqCst);
237+
plan
238+
});
239+
}
231240

232241
// Create 3 tasks with 3 partitions each.
233242
let num_tasks = 3;
@@ -297,6 +306,8 @@ mod tests {
297306
assert!(result.is_ok());
298307
}
299308
}
309+
// As many plans as tasks should have been received.
310+
assert_eq!(plans_received.load(Ordering::SeqCst), task_keys.len());
300311

301312
// Check that the endpoint has not evicted any task states.
302313
assert_eq!(endpoint.task_data_entries.len(), num_tasks as usize);

src/flight_service/service.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,30 @@ use arrow_flight::{
1010
use async_trait::async_trait;
1111
use datafusion::error::DataFusionError;
1212
use datafusion::execution::runtime_env::RuntimeEnv;
13+
use datafusion::physical_plan::ExecutionPlan;
1314
use futures::stream::BoxStream;
1415
use std::sync::Arc;
1516
use tokio::sync::OnceCell;
1617
use tonic::{Request, Response, Status, Streaming};
1718

19+
/// Callbacks that `ArrowFlightEndpoint` invokes while handling requests.
// NOTE(review): the trait-object type of `on_plan` is presumably what trips
// `clippy::type_complexity`, hence the allow below — confirm.
#[allow(clippy::type_complexity)]
20+
pub(super) struct ArrowFlightEndpointHooks {
21+
/// Applied to the physical plan decoded in `do_get`; the plan it returns
/// replaces the decoded one for the rest of the call.
pub(super) on_plan: Arc<dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send>,
22+
}
23+
24+
impl Default for ArrowFlightEndpointHooks {
25+
/// Builds a hook set whose `on_plan` callback is the identity function,
/// i.e. the plan is handed back unchanged.
fn default() -> Self {
26+
Self {
27+
// Pass-through: with no user callback registered the plan is left as-is.
on_plan: Arc::new(|plan| plan),
28+
}
29+
}
30+
}
31+
1832
pub struct ArrowFlightEndpoint {
1933
// Passed to `try_into_physical_plan` when the plan is decoded in `do_get`.
pub(super) runtime: Arc<RuntimeEnv>,
2034
// Task state keyed by `StageKey`; each entry is a `OnceCell` so it is
// initialised at most once. NOTE(review): the `TTLMap` name suggests
// entries expire after a time-to-live — confirm against its definition.
pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
2135
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
36+
// User-registered callbacks; `try_new` starts with the defaults.
pub(super) hooks: ArrowFlightEndpointHooks,
2237
}
2338

2439
impl ArrowFlightEndpoint {
@@ -30,8 +45,21 @@ impl ArrowFlightEndpoint {
3045
runtime: Arc::new(RuntimeEnv::default()),
3146
task_data_entries: Arc::new(ttl_map),
3247
session_builder: Arc::new(session_builder),
48+
hooks: ArrowFlightEndpointHooks::default(),
3349
})
3450
}
51+
52+
/// Sets the callback invoked when an [`ExecutionPlan`] is received in the `do_get` call.
53+
/// Calling this again replaces the previously registered callback; only one is kept.
54+
/// The callback takes the plan and must return either that same plan or one that is
55+
/// equivalent in terms of execution. Returning a plan with nodes added or removed
56+
/// will make the query blow up in unexpected ways.
57+
pub fn on_plan(
58+
&mut self,
59+
cbk: impl Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send + 'static,
60+
) {
61+
// Overwrites the stored hook rather than chaining onto it.
self.hooks.on_plan = Arc::new(cbk);
62+
}
3563
}
3664

3765
#[async_trait]

0 commit comments

Comments
 (0)