Skip to content

Commit df78006

Browse files
authored
Add arrow flight endpoint hooks (#198)
* Add arrow flight endpoint hooks * Store multiple hooks
1 parent 4ead8f1 commit df78006

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/flight_service/do_get.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@ impl ArrowFlightEndpoint {
9494
let stage_data = once
9595
.get_or_try_init(|| async {
9696
let proto_node = PhysicalPlanNode::try_decode(doget.plan_proto.as_ref())?;
97-
let plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?;
97+
let mut plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?;
98+
for hook in self.hooks.on_plan.iter() {
99+
plan = hook(plan)
100+
}
98101

99102
// Initialize partition count to the number of partitions in the stage
100103
let total_partitions = plan.properties().partitioning.partition_count();
@@ -226,8 +229,16 @@ mod tests {
226229
#[tokio::test]
227230
async fn test_task_data_partition_counting() {
228231
// Create ArrowFlightEndpoint with DefaultSessionBuilder
229-
let endpoint =
232+
let mut endpoint =
230233
ArrowFlightEndpoint::try_new(DefaultSessionBuilder).expect("Failed to create endpoint");
234+
let plans_received = Arc::new(AtomicUsize::default());
235+
{
236+
let plans_received = Arc::clone(&plans_received);
237+
endpoint.add_on_plan_hook(move |plan| {
238+
plans_received.fetch_add(1, Ordering::SeqCst);
239+
plan
240+
});
241+
}
231242

232243
// Create 3 tasks with 3 partitions each.
233244
let num_tasks = 3;
@@ -297,6 +308,8 @@ mod tests {
297308
assert!(result.is_ok());
298309
}
299310
}
311+
// As many plans as tasks should have been received.
312+
assert_eq!(plans_received.load(Ordering::SeqCst), task_keys.len());
300313

301314
// Check that the endpoint has not evicted any task states.
302315
assert_eq!(endpoint.task_data_entries.len(), num_tasks as usize);

src/flight_service/service.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,24 @@ use arrow_flight::{
1010
use async_trait::async_trait;
1111
use datafusion::error::DataFusionError;
1212
use datafusion::execution::runtime_env::RuntimeEnv;
13+
use datafusion::physical_plan::ExecutionPlan;
1314
use futures::stream::BoxStream;
1415
use std::sync::Arc;
1516
use tokio::sync::OnceCell;
1617
use tonic::{Request, Response, Status, Streaming};
1718

19+
#[allow(clippy::type_complexity)]
20+
#[derive(Default)]
21+
pub(super) struct ArrowFlightEndpointHooks {
22+
pub(super) on_plan:
23+
Vec<Arc<dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send>>,
24+
}
25+
1826
pub struct ArrowFlightEndpoint {
1927
pub(super) runtime: Arc<RuntimeEnv>,
2028
pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
2129
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
30+
pub(super) hooks: ArrowFlightEndpointHooks,
2231
}
2332

2433
impl ArrowFlightEndpoint {
@@ -30,8 +39,21 @@ impl ArrowFlightEndpoint {
3039
runtime: Arc::new(RuntimeEnv::default()),
3140
task_data_entries: Arc::new(ttl_map),
3241
session_builder: Arc::new(session_builder),
42+
hooks: ArrowFlightEndpointHooks::default(),
3343
})
3444
}
45+
46+
/// Adds a callback for when an [ExecutionPlan] is received in the `do_get` call.
47+
///
48+
/// The callback takes the plan and returns another plan that must be either the same,
49+
/// or equivalent in terms of execution. Mutating the plan by adding nodes or removing them
50+
/// will make the query blow up in unexpected ways.
51+
pub fn add_on_plan_hook(
52+
&mut self,
53+
hook: impl Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send + 'static,
54+
) {
55+
self.hooks.on_plan.push(Arc::new(hook));
56+
}
3557
}
3658

3759
#[async_trait]

0 commit comments

Comments
 (0)