diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 814f02c..a60d43d 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -94,7 +94,10 @@ impl ArrowFlightEndpoint { let stage_data = once .get_or_try_init(|| async { let proto_node = PhysicalPlanNode::try_decode(doget.plan_proto.as_ref())?; - let plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?; + let mut plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?; + for hook in self.hooks.on_plan.iter() { + plan = hook(plan) + } // Initialize partition count to the number of partitions in the stage let total_partitions = plan.properties().partitioning.partition_count(); @@ -226,8 +229,16 @@ mod tests { #[tokio::test] async fn test_task_data_partition_counting() { // Create ArrowFlightEndpoint with DefaultSessionBuilder - let endpoint = + let mut endpoint = ArrowFlightEndpoint::try_new(DefaultSessionBuilder).expect("Failed to create endpoint"); + let plans_received = Arc::new(AtomicUsize::default()); + { + let plans_received = Arc::clone(&plans_received); + endpoint.add_on_plan_hook(move |plan| { + plans_received.fetch_add(1, Ordering::SeqCst); + plan + }); + } // Create 3 tasks with 3 partitions each. let num_tasks = 3; @@ -297,6 +308,8 @@ mod tests { assert!(result.is_ok()); } } + // As many plans as tasks should have been received. + assert_eq!(plans_received.load(Ordering::SeqCst), task_keys.len()); // Check that the endpoint has not evicted any task states. assert_eq!(endpoint.task_data_entries.len(), num_tasks as usize); diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index df01be7..4b94346 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -10,15 +10,24 @@ use arrow_flight::{ use async_trait::async_trait; use datafusion::error::DataFusionError; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::physical_plan::ExecutionPlan; use futures::stream::BoxStream; use std::sync::Arc; use tokio::sync::OnceCell; use tonic::{Request, Response, Status, Streaming}; +#[allow(clippy::type_complexity)] +#[derive(Default)] +pub(super) struct ArrowFlightEndpointHooks { + pub(super) on_plan: + Vec) -> Arc + Sync + Send>>, +} + pub struct ArrowFlightEndpoint { pub(super) runtime: Arc, pub(super) task_data_entries: Arc>>>, pub(super) session_builder: Arc, + pub(super) hooks: ArrowFlightEndpointHooks, } impl ArrowFlightEndpoint { @@ -30,8 +39,21 @@ impl ArrowFlightEndpoint { runtime: Arc::new(RuntimeEnv::default()), task_data_entries: Arc::new(ttl_map), session_builder: Arc::new(session_builder), + hooks: ArrowFlightEndpointHooks::default(), }) } + + /// Adds a callback for when an [ExecutionPlan] is received in the `do_get` call. + /// + /// The callback takes the plan and returns another plan that must be either the same, + /// or equivalent in terms of execution. Mutating the plan by adding nodes or removing them + /// will make the query blow up in unexpected ways. + pub fn add_on_plan_hook( + &mut self, + hook: impl Fn(Arc) -> Arc + Sync + Send + 'static, + ) { + self.hooks.on_plan.push(Arc::new(hook)); + } } #[async_trait]